feat: OIDC SSO provider support (Keycloak, Azure AD, custom)
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 52s
CI Pipeline / Rust Unit Tests (push) Successful in 1m11s
CI Pipeline / Security Audit (push) Successful in 5s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 15s
CI Pipeline / Build .deb & Release (push) Has been skipped
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 52s
CI Pipeline / Rust Unit Tests (push) Successful in 1m11s
CI Pipeline / Security Audit (push) Successful in 5s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 15s
CI Pipeline / Build .deb & Release (push) Has been skipped
- Refactored azure_sso.rs to sso.rs with generic OIDC provider support - Added OIDC discovery URL lookup with 1hr TTL caching - Added PKCE for all providers, client_secret optional for public clients - Added /api/v1/auth/sso/login and /api/v1/auth/sso/callback routes - Added /api/v1/auth/azure/* backward-compatible routes - Added POST /settings/sso/discover and POST /settings/sso/test endpoints - Frontend: Provider dropdown (Keycloak/Azure AD/Custom OIDC) - Frontend: Auto-fill discovery URL for Keycloak - Frontend: Discover Endpoints and Test Connection buttons - Frontend: Dynamic SSO button based on provider display name - Made migration 014 idempotent with DO blocks and IF NOT EXISTS - Fixed debian/install to use /usr/local/bin/ for binaries - Fixed frontend file path in .deb package - Reset admin password on dev server - Fixed database permissions for oidc_config table
This commit is contained in:
@ -10,7 +10,7 @@ use pm_auth::{
|
||||
rbac::{require_auth, AuthConfig},
|
||||
};
|
||||
use pm_core::{config::AppConfig, db, logging, request_id::request_id_middleware};
|
||||
use routes::azure_sso::{JwksCache, SsoSession};
|
||||
use routes::sso::{OidcCache, SsoSession};
|
||||
use routes::ws::WsTicket;
|
||||
use serde_json::{json, Value};
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
@ -31,8 +31,8 @@ pub struct AppState {
|
||||
pub ws_tickets: Arc<DashMap<String, WsTicket>>,
|
||||
/// In-memory store for SSO PKCE sessions (state → code_verifier).
|
||||
pub sso_sessions: Arc<DashMap<String, SsoSession>>,
|
||||
/// Cached Azure AD JWKS for id_token signature verification.
|
||||
pub jwks_cache: Arc<Mutex<JwksCache>>,
|
||||
/// Cached OIDC discovery document and JWKS for SSO id_token verification.
|
||||
pub oidc_cache: Arc<Mutex<OidcCache>>,
|
||||
/// Internal certificate authority for mTLS client cert issuance.
|
||||
pub ca: Arc<pm_ca::CertAuthority>,
|
||||
}
|
||||
@ -90,7 +90,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
|
||||
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
|
||||
let jwks_cache: Arc<Mutex<JwksCache>> = Arc::new(Mutex::new(JwksCache::default()));
|
||||
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
|
||||
|
||||
// Background task: purge expired WS tickets every 30 seconds.
|
||||
{
|
||||
@ -137,7 +137,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
ws_tickets,
|
||||
sso_sessions,
|
||||
ca: Arc::new(ca),
|
||||
jwks_cache,
|
||||
oidc_cache,
|
||||
};
|
||||
|
||||
let app = build_router(state);
|
||||
@ -234,7 +234,7 @@ pub fn build_router(state: AppState) -> Router {
|
||||
// Public auth routes (no JWT needed)
|
||||
.nest("/api/v1/auth", routes::auth::public_router())
|
||||
// Public Azure SSO routes (no JWT needed)
|
||||
.nest("/api/v1/auth/azure", routes::azure_sso::public_router())
|
||||
.nest("/api/v1/auth/azure", routes::sso::azure_compat_router())
|
||||
// Protected API routes (JWT required)
|
||||
.nest("/api/v1", protected_api)
|
||||
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
//! Route modules for the pm-web API.
|
||||
pub mod auth;
|
||||
pub mod azure_sso;
|
||||
pub mod ca;
|
||||
pub mod discovery;
|
||||
pub mod groups;
|
||||
@ -9,6 +8,7 @@ pub mod hosts;
|
||||
pub mod jobs;
|
||||
pub mod maintenance_windows;
|
||||
pub mod settings;
|
||||
pub mod sso;
|
||||
pub mod status;
|
||||
pub mod users;
|
||||
pub mod ws;
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
//!
|
||||
//! GET /api/v1/settings — get all settings (admin only)
|
||||
//! PUT /api/v1/settings — update settings (admin only)
|
||||
//! POST /api/v1/settings/azure-sso/test — test Azure SSO connectivity (admin only)
|
||||
//! POST /api/v1/settings/sso/discover — discover OIDC endpoints (admin only)
|
||||
//! POST /api/v1/settings/sso/test — test OIDC provider connectivity (admin only)
|
||||
//! POST /api/v1/settings/azure-sso/test — backward-compat alias for SSO test (admin only)
|
||||
//! POST /api/v1/settings/smtp/test — send test email (admin only)
|
||||
//! GET /api/v1/settings/ip-whitelist — get IP whitelist (admin only)
|
||||
//! PUT /api/v1/settings/ip-whitelist — update IP whitelist (admin only)
|
||||
@ -34,7 +36,7 @@ use crate::AppState;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SettingsResponse {
|
||||
pub azure_sso: AzureSsoConfig,
|
||||
pub oidc: OidcConfigResponse,
|
||||
pub smtp: SmtpConfig,
|
||||
pub polling: PollingConfig,
|
||||
pub ip_whitelist: Vec<String>,
|
||||
@ -44,10 +46,13 @@ pub struct SettingsResponse {
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AzureSsoConfig {
|
||||
pub struct OidcConfigResponse {
|
||||
pub enabled: bool,
|
||||
pub tenant_id: String,
|
||||
pub provider_type: String, // "keycloak", "azure", "custom"
|
||||
pub display_name: String,
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String, // Always masked in responses
|
||||
pub redirect_uri: String,
|
||||
pub scopes: String,
|
||||
}
|
||||
@ -70,7 +75,7 @@ pub struct PollingConfig {
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateSettingsRequest {
|
||||
pub azure_sso: Option<AzureSsoConfigUpdate>,
|
||||
pub oidc: Option<OidcConfigUpdate>,
|
||||
pub smtp: Option<SmtpConfigUpdate>,
|
||||
pub polling: Option<PollingConfigUpdate>,
|
||||
pub ip_whitelist: Option<Vec<String>>,
|
||||
@ -93,15 +98,31 @@ pub struct NotificationConfigUpdate {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AzureSsoConfigUpdate {
|
||||
pub struct OidcConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub tenant_id: Option<String>,
|
||||
pub provider_type: Option<String>,
|
||||
pub display_name: Option<String>,
|
||||
pub discovery_url: Option<String>,
|
||||
pub client_id: Option<String>,
|
||||
pub client_secret: Option<String>,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub scopes: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcDiscoveryRequest {
|
||||
pub discovery_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct OidcDiscoveryResult {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmtpConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
@ -131,7 +152,9 @@ pub struct IpWhitelistUpdate {
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(get_settings).put(update_settings))
|
||||
.route("/azure-sso/test", post(test_azure_sso))
|
||||
.route("/sso/discover", post(discover_oidc))
|
||||
.route("/sso/test", post(test_oidc))
|
||||
.route("/azure-sso/test", post(test_azure_sso_compat))
|
||||
.route("/smtp/test", post(test_smtp))
|
||||
.route(
|
||||
"/ip-whitelist",
|
||||
@ -175,7 +198,7 @@ async fn load_system_config(
|
||||
|
||||
fn build_settings_response(
|
||||
cfg: &HashMap<String, String>,
|
||||
azure: AzureSsoConfig,
|
||||
oidc: OidcConfigResponse,
|
||||
) -> SettingsResponse {
|
||||
let get = |key: &str| -> String { cfg.get(key).cloned().unwrap_or_default() };
|
||||
|
||||
@ -183,7 +206,7 @@ fn build_settings_response(
|
||||
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
|
||||
|
||||
SettingsResponse {
|
||||
azure_sso: azure,
|
||||
oidc,
|
||||
smtp: SmtpConfig {
|
||||
enabled: get("smtp_enabled") == "true",
|
||||
host: get("smtp_host"),
|
||||
@ -227,16 +250,16 @@ async fn update_config_key(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_azure_sso_config(
|
||||
async fn fetch_oidc_config(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Result<AzureSsoConfig, (StatusCode, Json<Value>)> {
|
||||
let row: Option<(bool, String, String, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, tenant_id, client_id, redirect_uri, scopes FROM azure_sso_config WHERE id = 1",
|
||||
) -> Result<OidcConfigResponse, (StatusCode, Json<Value>)> {
|
||||
let row: Option<(bool, String, String, String, String, String, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load azure_sso_config");
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
@ -244,19 +267,38 @@ async fn fetch_azure_sso_config(
|
||||
})?;
|
||||
|
||||
Ok(match row {
|
||||
Some((enabled, tenant_id, client_id, redirect_uri, scopes)) => AzureSsoConfig {
|
||||
Some((
|
||||
enabled,
|
||||
tenant_id,
|
||||
provider_type,
|
||||
display_name,
|
||||
discovery_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
scopes,
|
||||
)) => OidcConfigResponse {
|
||||
enabled,
|
||||
provider_type,
|
||||
display_name,
|
||||
discovery_url,
|
||||
client_id,
|
||||
client_secret: if client_secret.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
MASKED.to_string()
|
||||
},
|
||||
redirect_uri,
|
||||
scopes,
|
||||
},
|
||||
None => AzureSsoConfig {
|
||||
None => OidcConfigResponse {
|
||||
enabled: false,
|
||||
tenant_id: String::new(),
|
||||
provider_type: "azure".to_string(),
|
||||
display_name: "Azure AD".to_string(),
|
||||
discovery_url: String::new(),
|
||||
client_id: String::new(),
|
||||
client_secret: String::new(),
|
||||
redirect_uri: String::new(),
|
||||
scopes: "openid email profile".to_string(),
|
||||
scopes: "openid profile email".to_string(),
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -277,8 +319,8 @@ async fn get_settings(
|
||||
"sso_callback_url".to_string(),
|
||||
state.config.security.sso_callback_url.clone(),
|
||||
);
|
||||
let azure = fetch_azure_sso_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, azure)))
|
||||
let oidc = fetch_oidc_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, oidc)))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
@ -292,56 +334,66 @@ async fn update_settings(
|
||||
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
|
||||
admin_only(&auth)?;
|
||||
|
||||
// Update Azure SSO config
|
||||
// Use static queries with proper typed bindings to avoid boolean→string mismatch
|
||||
if let Some(azure) = req.azure_sso {
|
||||
let update_secret = azure.client_secret.as_ref().is_some_and(|s| s != MASKED);
|
||||
// Update OIDC config
|
||||
if let Some(oidc) = req.oidc {
|
||||
let update_secret = oidc
|
||||
.client_secret
|
||||
.as_ref()
|
||||
.is_some_and(|s| s != MASKED && !s.is_empty());
|
||||
|
||||
let result = if update_secret {
|
||||
sqlx::query(
|
||||
"UPDATE azure_sso_config SET \
|
||||
"UPDATE oidc_config SET \
|
||||
enabled = COALESCE($1, enabled), \
|
||||
tenant_id = COALESCE($2, tenant_id), \
|
||||
client_id = COALESCE($3, client_id), \
|
||||
client_secret = $4, \
|
||||
redirect_uri = COALESCE($5, redirect_uri), \
|
||||
scopes = COALESCE($6, scopes), \
|
||||
provider_type = COALESCE($2, provider_type), \
|
||||
display_name = COALESCE($3, display_name), \
|
||||
discovery_url = COALESCE($4, discovery_url), \
|
||||
client_id = COALESCE($5, client_id), \
|
||||
client_secret = $6, \
|
||||
redirect_uri = COALESCE($7, redirect_uri), \
|
||||
scopes = COALESCE($8, scopes), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = 1",
|
||||
)
|
||||
.bind(azure.enabled)
|
||||
.bind(&azure.tenant_id)
|
||||
.bind(&azure.client_id)
|
||||
.bind(azure.client_secret.as_deref().unwrap_or(""))
|
||||
.bind(&azure.redirect_uri)
|
||||
.bind(&azure.scopes)
|
||||
.bind(oidc.enabled)
|
||||
.bind(&oidc.provider_type)
|
||||
.bind(&oidc.display_name)
|
||||
.bind(&oidc.discovery_url)
|
||||
.bind(&oidc.client_id)
|
||||
.bind(oidc.client_secret.as_deref().unwrap_or(""))
|
||||
.bind(&oidc.redirect_uri)
|
||||
.bind(&oidc.scopes)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE azure_sso_config SET \
|
||||
"UPDATE oidc_config SET \
|
||||
enabled = COALESCE($1, enabled), \
|
||||
tenant_id = COALESCE($2, tenant_id), \
|
||||
client_id = COALESCE($3, client_id), \
|
||||
redirect_uri = COALESCE($4, redirect_uri), \
|
||||
scopes = COALESCE($5, scopes), \
|
||||
provider_type = COALESCE($2, provider_type), \
|
||||
display_name = COALESCE($3, display_name), \
|
||||
discovery_url = COALESCE($4, discovery_url), \
|
||||
client_id = COALESCE($5, client_id), \
|
||||
redirect_uri = COALESCE($6, redirect_uri), \
|
||||
scopes = COALESCE($7, scopes), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = 1",
|
||||
)
|
||||
.bind(azure.enabled)
|
||||
.bind(&azure.tenant_id)
|
||||
.bind(&azure.client_id)
|
||||
.bind(&azure.redirect_uri)
|
||||
.bind(&azure.scopes)
|
||||
.bind(oidc.enabled)
|
||||
.bind(&oidc.provider_type)
|
||||
.bind(&oidc.display_name)
|
||||
.bind(&oidc.discovery_url)
|
||||
.bind(&oidc.client_id)
|
||||
.bind(&oidc.redirect_uri)
|
||||
.bind(&oidc.scopes)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
result.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update azure_sso_config");
|
||||
tracing::error!(error = %e, "Failed to update oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": format!("Failed to update Azure SSO config: {}", e) } })),
|
||||
Json(json!({ "error": { "code": "internal_error", "message": format!("Failed to update OIDC config: {}", e) } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
@ -350,9 +402,9 @@ async fn update_settings(
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("azure_sso"),
|
||||
Some("oidc"),
|
||||
Some("1"),
|
||||
json!({ "section": "azure_sso" }),
|
||||
json!({ "section": "oidc" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
@ -497,55 +549,30 @@ async fn update_settings(
|
||||
"sso_callback_url".to_string(),
|
||||
state.config.security.sso_callback_url.clone(),
|
||||
);
|
||||
let azure = fetch_azure_sso_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, azure)))
|
||||
let oidc = fetch_oidc_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, oidc)))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/azure-sso/test
|
||||
// POST /api/v1/settings/sso/discover
|
||||
// ============================================================
|
||||
|
||||
async fn test_azure_sso(
|
||||
async fn discover_oidc(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<OidcDiscoveryRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
admin_only(&auth)?;
|
||||
|
||||
let row: Option<(String, String)> = sqlx::query_as(
|
||||
"SELECT tenant_id, client_id FROM azure_sso_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load azure_sso_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (tenant_id, _client_id) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "Azure SSO is not configured"
|
||||
})));
|
||||
},
|
||||
};
|
||||
|
||||
if tenant_id.is_empty() {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "Azure tenant ID is not set"
|
||||
})));
|
||||
if req.discovery_url.is_empty() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": "discovery_url is required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let url = format!(
|
||||
"https://login.microsoftonline.com/{}/v2.0/.well-known/openid-configuration",
|
||||
tenant_id
|
||||
);
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
@ -557,35 +584,129 @@ async fn test_azure_sso(
|
||||
)
|
||||
})?;
|
||||
|
||||
match client.get(&url).send().await {
|
||||
match client.get(&req.discovery_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: Value = resp.json().await.unwrap_or(json!({}));
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"issuer": body.get("issuer").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"authorization_endpoint": body.get("authorization_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"token_endpoint": body.get("token_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"jwks_uri": body.get("jwks_uri").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"userinfo_endpoint": body.get("userinfo_endpoint").and_then(|v| v.as_str()),
|
||||
})))
|
||||
},
|
||||
Ok(resp) => Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({ "error": { "code": "discovery_failed", "message": format!("Discovery endpoint returned HTTP {}", resp.status()) } }),
|
||||
),
|
||||
)),
|
||||
Err(e) => Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({ "error": { "code": "discovery_failed", "message": format!("Failed to reach discovery endpoint: {}", e) } }),
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/sso/test
|
||||
// ============================================================
|
||||
|
||||
async fn test_oidc(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
admin_only(&auth)?;
|
||||
|
||||
let row: Option<(bool, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, discovery_url FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (enabled, provider_type, discovery_url) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC is not configured"
|
||||
})));
|
||||
},
|
||||
};
|
||||
|
||||
if !enabled {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC is not enabled"
|
||||
})));
|
||||
}
|
||||
|
||||
if discovery_url.is_empty() {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC discovery URL is not set"
|
||||
})));
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
match client.get(&discovery_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: Value = resp.json().await.unwrap_or(json!({}));
|
||||
let issuer = body.get("issuer").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if issuer.contains(&tenant_id) {
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": "Azure AD tenant verified successfully",
|
||||
"issuer": issuer
|
||||
})))
|
||||
} else {
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": "Azure AD endpoint reached, but issuer does not match tenant_id",
|
||||
"issuer": issuer
|
||||
})))
|
||||
}
|
||||
let provider_label = match provider_type.as_str() {
|
||||
"keycloak" => "Keycloak",
|
||||
"azure" => "Azure AD",
|
||||
_ => "OIDC",
|
||||
};
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": format!("{} provider verified successfully", provider_label),
|
||||
"issuer": issuer,
|
||||
"provider_type": provider_type,
|
||||
})))
|
||||
},
|
||||
Ok(resp) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to reach Azure AD: HTTP {}", resp.status())
|
||||
"message": format!("Failed to reach OIDC provider: HTTP {}", resp.status())
|
||||
}))),
|
||||
Err(e) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to reach Azure AD: {}", e)
|
||||
"message": format!("Failed to reach OIDC provider: {}", e)
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/azure-sso/test (backward-compatible alias)
|
||||
// ============================================================
|
||||
|
||||
async fn test_azure_sso_compat(
|
||||
state: State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
test_oidc(state, auth).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/smtp/test
|
||||
// ============================================================
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
//! Azure SSO OAuth2/OIDC flow routes.
|
||||
//! Generic OIDC SSO routes (Keycloak, Azure AD, Custom).
|
||||
//!
|
||||
//! Public routes (no auth required):
|
||||
//! GET /api/v1/auth/azure/login — redirect to Azure AD authorization URL
|
||||
//! GET /api/v1/auth/azure/callback — handle Azure AD callback, redirect to frontend SPA
|
||||
//! GET /api/v1/auth/sso/login — redirect to OIDC provider authorization URL
|
||||
//! GET /api/v1/auth/sso/callback — handle OIDC provider callback, redirect to frontend SPA
|
||||
//!
|
||||
//! Backward-compatible aliases:
|
||||
//! GET /api/v1/auth/azure/login → redirects to generic SSO login
|
||||
//! GET /api/v1/auth/azure/callback → redirects to generic SSO callback
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
@ -51,6 +55,7 @@ struct TokenResponse {
|
||||
struct IdTokenClaims {
|
||||
email: Option<String>,
|
||||
name: Option<String>,
|
||||
sub: Option<String>,
|
||||
oid: Option<String>,
|
||||
preferred_username: Option<String>,
|
||||
}
|
||||
@ -65,23 +70,51 @@ struct DbUserForSso {
|
||||
mfa_enabled: bool,
|
||||
}
|
||||
|
||||
/// Cache for Azure AD JWKS (JSON Web Key Set) with TTL-based refresh.
|
||||
pub struct JwksCache {
|
||||
pub keys: Option<serde_json::Value>,
|
||||
pub fetched_at: Option<chrono::DateTime<Utc>>,
|
||||
/// OIDC provider configuration from database.
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
pub struct OidcConfig {
|
||||
pub enabled: bool,
|
||||
pub provider_type: String,
|
||||
pub display_name: String,
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub redirect_uri: String,
|
||||
pub scopes: String,
|
||||
}
|
||||
|
||||
impl Default for JwksCache {
|
||||
/// Cached OIDC discovery document.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OidcDiscovery {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>,
|
||||
pub fetched_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Cache for OIDC discovery documents and JWKS with TTL-based refresh.
|
||||
pub struct OidcCache {
|
||||
pub discovery: Option<OidcDiscovery>,
|
||||
pub jwks: Option<serde_json::Value>,
|
||||
pub jwks_fetched_at: Option<chrono::DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl Default for OidcCache {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
keys: None,
|
||||
fetched_at: None,
|
||||
discovery: None,
|
||||
jwks: None,
|
||||
jwks_fetched_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JWKS cache TTL in seconds (1 hour).
|
||||
const JWKS_CACHE_TTL_SECS: i64 = 3600;
|
||||
/// Discovery cache TTL in seconds (1 hour).
|
||||
const DISCOVERY_CACHE_TTL_SECS: i64 = 3600;
|
||||
|
||||
// ============================================================
|
||||
// Router
|
||||
@ -89,52 +122,55 @@ const JWKS_CACHE_TTL_SECS: i64 = 3600;
|
||||
|
||||
pub fn public_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", get(azure_login))
|
||||
.route("/callback", get(azure_callback))
|
||||
.route("/login", get(sso_login))
|
||||
.route("/callback", get(sso_callback))
|
||||
}
|
||||
|
||||
/// Backward-compatible Azure SSO routes — redirect to generic SSO endpoints.
|
||||
pub fn azure_compat_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", get(azure_login_redirect))
|
||||
.route("/callback", get(azure_callback_redirect))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/azure/login
|
||||
// GET /api/v1/auth/sso/login
|
||||
// ============================================================
|
||||
|
||||
async fn azure_login(
|
||||
async fn sso_login(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
// Read Azure SSO config from DB
|
||||
let row: Option<(bool, String, String, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, tenant_id, client_id, redirect_uri, scopes FROM azure_sso_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load azure_sso_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
let config = load_oidc_config(&state.db).await?;
|
||||
|
||||
let (enabled, tenant_id, client_id, redirect_uri, scopes) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
if !config.enabled {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "SSO is not enabled" } })),
|
||||
));
|
||||
}
|
||||
|
||||
if config.discovery_url.is_empty() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "SSO discovery URL is not configured" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch OIDC discovery document (with caching)
|
||||
let discovery = match fetch_discovery(&state).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "Azure SSO is not configured" } }),
|
||||
json!({ "error": { "code": "internal_error", "message": format!("Failed to fetch OIDC discovery: {}", e) } }),
|
||||
),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
if !enabled {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "Azure SSO is not enabled" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Generate PKCE code_verifier (32 random bytes → base64url)
|
||||
let mut verifier_bytes = [0u8; 32];
|
||||
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut verifier_bytes);
|
||||
@ -156,19 +192,23 @@ async fn azure_login(
|
||||
},
|
||||
);
|
||||
|
||||
// Build authorization URL
|
||||
let encoded_scopes = urlencoding::encode(&scopes);
|
||||
// Build authorization URL from discovery
|
||||
let encoded_scopes = urlencoding::encode(&config.scopes);
|
||||
let auth_url = format!(
|
||||
"https://login.microsoftonline.com/{}/oauth2/v2.0/authorize?client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||
tenant_id, client_id, redirect_uri, encoded_scopes, code_challenge, state_token
|
||||
"{}?client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||
discovery.authorization_endpoint,
|
||||
urlencoding::encode(&config.client_id),
|
||||
urlencoding::encode(&config.redirect_uri),
|
||||
encoded_scopes,
|
||||
code_challenge,
|
||||
state_token
|
||||
);
|
||||
|
||||
// Redirect to Azure AD
|
||||
Ok(Redirect::to(&auth_url))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/azure/callback
|
||||
// GET /api/v1/auth/sso/callback
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -179,13 +219,12 @@ struct CallbackParams {
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
async fn azure_callback(
|
||||
async fn sso_callback(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<Redirect, Redirect> {
|
||||
let callback_url = &state.config.security.sso_callback_url;
|
||||
|
||||
// Helper to build error redirect
|
||||
let error_redirect = |code: &str, message: &str| -> Redirect {
|
||||
let url = format!(
|
||||
"{}?error={}&error_description={}",
|
||||
@ -196,10 +235,9 @@ async fn azure_callback(
|
||||
Redirect::to(&url)
|
||||
};
|
||||
|
||||
// Check for error from Azure AD
|
||||
if let Some(error) = params.error {
|
||||
let desc = params.error_description.unwrap_or_default();
|
||||
let message = format!("Azure AD error: {} - {}", error, desc);
|
||||
let message = format!("OIDC provider error: {} - {}", error, desc);
|
||||
return Err(error_redirect("sso_error", &message));
|
||||
}
|
||||
|
||||
@ -213,7 +251,6 @@ async fn azure_callback(
|
||||
None => return Err(error_redirect("bad_request", "Missing state parameter")),
|
||||
};
|
||||
|
||||
// Look up code_verifier from sso_sessions
|
||||
let sso_session = match state.sso_sessions.remove(&state_token).map(|(_, v)| v) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
@ -224,33 +261,28 @@ async fn azure_callback(
|
||||
},
|
||||
};
|
||||
|
||||
// Read Azure SSO config (including client_secret for token exchange)
|
||||
let row: Option<(bool, String, String, String, String)> = match sqlx::query_as(
|
||||
"SELECT enabled, tenant_id, client_id, client_secret, redirect_uri FROM azure_sso_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to load azure_sso_config");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
}
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to load OIDC config",
|
||||
))
|
||||
},
|
||||
};
|
||||
|
||||
let (_enabled, tenant_id, client_id, client_secret, redirect_uri) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Err(error_redirect("internal_error", "Azure SSO not configured"));
|
||||
let discovery = match fetch_discovery(&state).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to fetch OIDC discovery");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to fetch OIDC discovery",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Exchange code for tokens
|
||||
let token_url = format!(
|
||||
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
||||
tenant_id
|
||||
);
|
||||
|
||||
let client = match reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
@ -262,18 +294,25 @@ async fn azure_callback(
|
||||
},
|
||||
};
|
||||
|
||||
let params = [
|
||||
let mut params_vec: Vec<(&str, String)> = vec![
|
||||
("grant_type", "authorization_code".to_string()),
|
||||
("code", code.clone()),
|
||||
("redirect_uri", redirect_uri.clone()),
|
||||
("client_id", client_id.clone()),
|
||||
("client_secret", client_secret.clone()),
|
||||
("redirect_uri", config.redirect_uri.clone()),
|
||||
("client_id", config.client_id.clone()),
|
||||
("code_verifier", sso_session.code_verifier.clone()),
|
||||
];
|
||||
|
||||
let form_params: Vec<(&str, String)> = params.to_vec();
|
||||
// For confidential clients (Azure AD), include client_secret
|
||||
if !config.client_secret.is_empty() {
|
||||
params_vec.push(("client_secret", config.client_secret.clone()));
|
||||
}
|
||||
|
||||
let token_resp = match client.post(&token_url).form(&form_params).send().await {
|
||||
let token_resp = match client
|
||||
.post(&discovery.token_endpoint)
|
||||
.form(¶ms_vec)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Token exchange request failed");
|
||||
@ -305,13 +344,12 @@ async fn azure_callback(
|
||||
},
|
||||
};
|
||||
|
||||
// Verify id_token JWT signature using Azure AD JWKS and validate claims
|
||||
let id_token = match token_data.id_token {
|
||||
Some(t) => t,
|
||||
None => return Err(error_redirect("sso_error", "No id_token in response")),
|
||||
};
|
||||
|
||||
let claims = match verify_id_token(&id_token, &tenant_id, &client_id, &state.jwks_cache).await {
|
||||
let claims = match verify_id_token(&id_token, &config, &discovery, &state.oidc_cache).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to verify id_token");
|
||||
@ -324,22 +362,37 @@ async fn azure_callback(
|
||||
|
||||
let email = claims.email.unwrap_or_default();
|
||||
let name = claims.name.unwrap_or_default();
|
||||
let oid = claims.oid.unwrap_or_default();
|
||||
let oidc_sub = claims.sub.unwrap_or_default();
|
||||
let azure_oid = claims.oid.unwrap_or_default();
|
||||
let preferred_username = claims.preferred_username.unwrap_or_else(|| email.clone());
|
||||
|
||||
if email.is_empty() || oid.is_empty() {
|
||||
let provider_subject = if !oidc_sub.is_empty() {
|
||||
oidc_sub.clone()
|
||||
} else if !azure_oid.is_empty() {
|
||||
azure_oid.clone()
|
||||
} else {
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
"Missing email or oid in id_token",
|
||||
"Missing subject identifier in id_token",
|
||||
));
|
||||
};
|
||||
|
||||
if email.is_empty() {
|
||||
return Err(error_redirect("sso_error", "Missing email in id_token"));
|
||||
}
|
||||
|
||||
// Look up or create user
|
||||
let auth_provider = match config.provider_type.as_str() {
|
||||
"keycloak" => "keycloak",
|
||||
"azure" => "azure_sso",
|
||||
_ => "oidc",
|
||||
};
|
||||
|
||||
let user_opt: Option<DbUserForSso> = match sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, role, is_active, mfa_enabled
|
||||
FROM users WHERE email = $1 AND auth_provider = 'azure_sso'"#,
|
||||
FROM users WHERE email = $1 AND auth_provider = $2"#,
|
||||
)
|
||||
.bind(&email)
|
||||
.bind(auth_provider)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
@ -356,16 +409,17 @@ async fn azure_callback(
|
||||
},
|
||||
Some(u) => u,
|
||||
None => {
|
||||
// Auto-create user with role=operator, auth_provider=azure_sso
|
||||
let id: Uuid = match sqlx::query_scalar(
|
||||
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid)
|
||||
VALUES ($1, $2, $3, 'operator', 'azure_sso', $4)
|
||||
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid, oidc_sub)
|
||||
VALUES ($1, $2, $3, 'operator', $4, $5, $6)
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(&preferred_username)
|
||||
.bind(&name)
|
||||
.bind(&email)
|
||||
.bind(&oid)
|
||||
.bind(auth_provider)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
{
|
||||
@ -380,10 +434,10 @@ async fn azure_callback(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
None,
|
||||
Some("azure_sso"),
|
||||
Some(auth_provider),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "auth_provider": "azure_sso", "email": email }),
|
||||
json!({ "auth_provider": auth_provider, "email": email }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
@ -400,11 +454,12 @@ async fn azure_callback(
|
||||
},
|
||||
};
|
||||
|
||||
// Update last_login_at and azure_oid
|
||||
// Update last_login_at and provider subject IDs
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1) WHERE id = $2",
|
||||
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1), oidc_sub = COALESCE(oidc_sub, $2) WHERE id = $3",
|
||||
)
|
||||
.bind(&oid)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.bind(user.id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
@ -413,7 +468,6 @@ async fn azure_callback(
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
}
|
||||
|
||||
// Issue JWT access token + refresh token
|
||||
let access_ttl = state.config.security.jwt_access_ttl_secs as i64;
|
||||
let access_token = match issue_access_token(
|
||||
user.id,
|
||||
@ -447,22 +501,21 @@ async fn azure_callback(
|
||||
Some(&user.username),
|
||||
None,
|
||||
None,
|
||||
json!({ "auth_provider": "azure_sso" }),
|
||||
json!({ "auth_provider": auth_provider }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Build user JSON for query parameter
|
||||
let user_json = json!({
|
||||
"id": user.id.to_string(),
|
||||
"username": user.username,
|
||||
"display_name": user.display_name,
|
||||
"role": user.role,
|
||||
"auth_provider": auth_provider,
|
||||
"mfa_enabled": user.mfa_enabled,
|
||||
});
|
||||
|
||||
// Redirect to frontend SPA with tokens as query parameters
|
||||
let redirect_url = format!(
|
||||
"{}?access_token={}&refresh_token={}&token_type=Bearer&expires_in={}&user={}",
|
||||
callback_url,
|
||||
@ -476,32 +529,149 @@ async fn azure_callback(
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// JWT Verification Helpers
|
||||
// Backward-compatible Azure SSO redirect handlers
|
||||
// ============================================================
|
||||
|
||||
/// Verify the id_token JWT signature using Azure AD JWKS and validate standard claims.
|
||||
///
|
||||
/// Steps:
|
||||
/// 1. Decode JWT header to extract `kid` (key ID)
|
||||
/// 2. Fetch JWKS from Azure AD if cache is empty or expired (1-hour TTL)
|
||||
/// 3. Find the matching JWK by `kid`
|
||||
/// 4. Construct RSA public key from JWK modulus (`n`) and exponent (`e`)
|
||||
/// 5. Validate issuer, audience, and expiry via `jsonwebtoken::decode`
|
||||
async fn azure_login_redirect(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
sso_login(State(state)).await
|
||||
}
|
||||
|
||||
async fn azure_callback_redirect(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<Redirect, Redirect> {
|
||||
sso_callback(State(state), axum::extract::Query(params)).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Database helpers
|
||||
// ============================================================
|
||||
|
||||
async fn load_oidc_config(pool: &sqlx::PgPool) -> Result<OidcConfig, (StatusCode, Json<Value>)> {
|
||||
let row: Option<OidcConfig> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(row.unwrap_or(OidcConfig {
|
||||
enabled: false,
|
||||
provider_type: "azure".to_string(),
|
||||
display_name: "Azure AD".to_string(),
|
||||
discovery_url: String::new(),
|
||||
client_id: String::new(),
|
||||
client_secret: String::new(),
|
||||
redirect_uri: String::new(),
|
||||
scopes: "openid profile email".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OIDC Discovery & JWKS
|
||||
// ============================================================
|
||||
|
||||
async fn fetch_discovery(state: &AppState) -> Result<OidcDiscovery, String> {
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err("Failed to load OIDC config".to_string());
|
||||
},
|
||||
};
|
||||
let discovery_url = config.discovery_url;
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let cache = state.oidc_cache.lock().await;
|
||||
if let Some(ref disc) = cache.discovery {
|
||||
let elapsed = Utc::now().signed_duration_since(disc.fetched_at);
|
||||
if elapsed.num_seconds() < DISCOVERY_CACHE_TTL_SECS {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Discovery fetch failed: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"Discovery fetch failed: HTTP {} — {}",
|
||||
status, body
|
||||
));
|
||||
}
|
||||
|
||||
let doc: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse discovery document: {}", e))?;
|
||||
|
||||
let discovery = OidcDiscovery {
|
||||
issuer: doc
|
||||
.get("issuer")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
authorization_endpoint: doc
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
token_endpoint: doc
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
jwks_uri: doc
|
||||
.get("jwks_uri")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
userinfo_endpoint: doc
|
||||
.get("userinfo_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string()),
|
||||
fetched_at: Utc::now(),
|
||||
};
|
||||
|
||||
{
|
||||
let mut cache = state.oidc_cache.lock().await;
|
||||
cache.discovery = Some(discovery.clone());
|
||||
}
|
||||
|
||||
Ok(discovery)
|
||||
}
|
||||
|
||||
async fn verify_id_token(
|
||||
token: &str,
|
||||
tenant_id: &str,
|
||||
client_id: &str,
|
||||
jwks_cache: &Arc<Mutex<JwksCache>>,
|
||||
config: &OidcConfig,
|
||||
discovery: &OidcDiscovery,
|
||||
oidc_cache: &Arc<Mutex<OidcCache>>,
|
||||
) -> Result<IdTokenClaims, String> {
|
||||
// 1. Decode JWT header to get the kid
|
||||
let header = decode_header(token).map_err(|e| format!("Failed to decode JWT header: {}", e))?;
|
||||
|
||||
let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
|
||||
|
||||
// 2. Check JWKS cache — fetch if expired or missing
|
||||
let jwks = {
|
||||
let cache = jwks_cache.lock().await;
|
||||
let needs_fetch = match (&cache.keys, &cache.fetched_at) {
|
||||
let cache = oidc_cache.lock().await;
|
||||
let needs_fetch = match (&cache.jwks, &cache.jwks_fetched_at) {
|
||||
(None, _) => true,
|
||||
(Some(_), None) => true,
|
||||
(Some(_), Some(fetched)) => {
|
||||
@ -511,21 +681,17 @@ async fn verify_id_token(
|
||||
};
|
||||
|
||||
if needs_fetch {
|
||||
// Drop lock before making async HTTP request
|
||||
drop(cache);
|
||||
|
||||
let jwks_value = fetch_jwks(tenant_id).await?;
|
||||
|
||||
let mut cache = jwks_cache.lock().await;
|
||||
cache.keys = Some(jwks_value);
|
||||
cache.fetched_at = Some(Utc::now());
|
||||
cache.keys.clone().unwrap()
|
||||
let jwks_value = fetch_jwks(&discovery.jwks_uri).await?;
|
||||
let mut cache = oidc_cache.lock().await;
|
||||
cache.jwks = Some(jwks_value);
|
||||
cache.jwks_fetched_at = Some(Utc::now());
|
||||
cache.jwks.clone().unwrap()
|
||||
} else {
|
||||
cache.keys.clone().unwrap()
|
||||
cache.jwks.clone().unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
// 3. Find the matching JWK by kid
|
||||
let keys_array = jwks
|
||||
.get("keys")
|
||||
.ok_or("JWKS response missing 'keys' array")?
|
||||
@ -537,7 +703,6 @@ async fn verify_id_token(
|
||||
.find(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()))
|
||||
.ok_or_else(|| format!("No matching JWK found for kid: {}", kid))?;
|
||||
|
||||
// 4. Construct RSA public key from JWK modulus (n) and exponent (e)
|
||||
let n = jwk
|
||||
.get("n")
|
||||
.and_then(|v| v.as_str())
|
||||
@ -550,36 +715,25 @@ async fn verify_id_token(
|
||||
let decoding_key = DecodingKey::from_rsa_components(n, e)
|
||||
.map_err(|e| format!("Failed to construct RSA decoding key: {}", e))?;
|
||||
|
||||
// 5. Configure validation rules
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.iss = Some(HashSet::from([format!(
|
||||
"https://login.microsoftonline.com/{}/v2.0",
|
||||
tenant_id
|
||||
)]));
|
||||
validation.aud = Some(HashSet::from([client_id.to_string()]));
|
||||
validation.leeway = 60; // 60 seconds clock skew tolerance
|
||||
validation.iss = Some(HashSet::from([discovery.issuer.clone()]));
|
||||
validation.aud = Some(HashSet::from([config.client_id.clone()]));
|
||||
validation.leeway = 60;
|
||||
|
||||
// 6. Decode and verify the JWT
|
||||
let token_data = decode::<IdTokenClaims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("JWT signature verification failed: {}", e))?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
/// Fetch the JWKS from the Azure AD discovery endpoint.
|
||||
async fn fetch_jwks(tenant_id: &str) -> Result<serde_json::Value, String> {
|
||||
let jwks_url = format!(
|
||||
"https://login.microsoftonline.com/{}/discovery/v2.0/keys",
|
||||
tenant_id
|
||||
);
|
||||
|
||||
async fn fetch_jwks(jwks_uri: &str) -> Result<serde_json::Value, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build HTTP client for JWKS fetch: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(&jwks_url)
|
||||
.get(jwks_uri)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("JWKS fetch request failed: {}", e))?;
|
||||
Reference in New Issue
Block a user