|
|
|
|
@ -2,7 +2,7 @@
|
|
|
|
|
//!
|
|
|
|
|
//! Public routes (no auth required):
|
|
|
|
|
//! GET /api/v1/auth/azure/login — redirect to Azure AD authorization URL
|
|
|
|
|
//! GET /api/v1/auth/azure/callback — handle Azure AD callback
|
|
|
|
|
//! GET /api/v1/auth/azure/callback — handle Azure AD callback, redirect to frontend SPA
|
|
|
|
|
|
|
|
|
|
use axum::{
|
|
|
|
|
extract::State,
|
|
|
|
|
@ -13,11 +13,15 @@ use axum::{
|
|
|
|
|
};
|
|
|
|
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
|
|
|
|
use chrono::Utc;
|
|
|
|
|
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
|
|
|
|
|
use pm_auth::{jwt::issue_access_token, refresh};
|
|
|
|
|
use pm_core::audit::{log_event, AuditAction};
|
|
|
|
|
use serde::Deserialize;
|
|
|
|
|
use serde_json::{json, Value};
|
|
|
|
|
use sha2::{Digest, Sha256};
|
|
|
|
|
use std::collections::HashSet;
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
use tokio::sync::Mutex;
|
|
|
|
|
use uuid::Uuid;
|
|
|
|
|
|
|
|
|
|
use crate::AppState;
|
|
|
|
|
@ -61,6 +65,24 @@ struct DbUserForSso {
|
|
|
|
|
mfa_enabled: bool,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Cache for Azure AD JWKS (JSON Web Key Set) with TTL-based refresh.
|
|
|
|
|
pub struct JwksCache {
|
|
|
|
|
pub keys: Option<serde_json::Value>,
|
|
|
|
|
pub fetched_at: Option<chrono::DateTime<Utc>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Default for JwksCache {
|
|
|
|
|
fn default() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
keys: None,
|
|
|
|
|
fetched_at: None,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// JWKS cache TTL in seconds (1 hour).
|
|
|
|
|
const JWKS_CACHE_TTL_SECS: i64 = 3600;
|
|
|
|
|
|
|
|
|
|
// ============================================================
|
|
|
|
|
// Router
|
|
|
|
|
// ============================================================
|
|
|
|
|
@ -160,69 +182,61 @@ struct CallbackParams {
|
|
|
|
|
async fn azure_callback(
|
|
|
|
|
State(state): State<AppState>,
|
|
|
|
|
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
|
|
|
|
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
|
|
|
|
) -> Result<Redirect, Redirect> {
|
|
|
|
|
let callback_url = &state.config.security.sso_callback_url;
|
|
|
|
|
|
|
|
|
|
// Helper to build error redirect
|
|
|
|
|
let error_redirect = |code: &str, message: &str| -> Redirect {
|
|
|
|
|
let url = format!(
|
|
|
|
|
"{}?error={}&error_description={}",
|
|
|
|
|
callback_url,
|
|
|
|
|
urlencoding::encode(code),
|
|
|
|
|
urlencoding::encode(message)
|
|
|
|
|
);
|
|
|
|
|
Redirect::to(&url)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Check for error from Azure AD
|
|
|
|
|
if let Some(error) = params.error {
|
|
|
|
|
let desc = params.error_description.unwrap_or_default();
|
|
|
|
|
return Err((
|
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
|
Json(
|
|
|
|
|
json!({ "error": { "code": "sso_error", "message": format!("Azure AD error: {} - {}", error, desc) } }),
|
|
|
|
|
),
|
|
|
|
|
));
|
|
|
|
|
let message = format!("Azure AD error: {} - {}", error, desc);
|
|
|
|
|
return Err(error_redirect("sso_error", &message));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let code = params.code.ok_or_else(|| {
|
|
|
|
|
(
|
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
|
Json(json!({ "error": { "code": "bad_request", "message": "Missing authorization code" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
let code = match params.code {
|
|
|
|
|
Some(c) => c,
|
|
|
|
|
None => return Err(error_redirect("bad_request", "Missing authorization code")),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let state_token = params.state.ok_or_else(|| {
|
|
|
|
|
(
|
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
|
Json(
|
|
|
|
|
json!({ "error": { "code": "bad_request", "message": "Missing state parameter" } }),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
let state_token = match params.state {
|
|
|
|
|
Some(s) => s,
|
|
|
|
|
None => return Err(error_redirect("bad_request", "Missing state parameter")),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Look up code_verifier from sso_sessions
|
|
|
|
|
let sso_session = state
|
|
|
|
|
.sso_sessions
|
|
|
|
|
.remove(&state_token)
|
|
|
|
|
.map(|(_, v)| v)
|
|
|
|
|
.ok_or_else(|| {
|
|
|
|
|
(
|
|
|
|
|
StatusCode::BAD_REQUEST,
|
|
|
|
|
Json(json!({ "error": { "code": "bad_request", "message": "Invalid or expired state token" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
let sso_session = match state.sso_sessions.remove(&state_token).map(|(_, v)| v) {
|
|
|
|
|
Some(s) => s,
|
|
|
|
|
None => return Err(error_redirect("bad_request", "Invalid or expired state token")),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Read Azure SSO config (including client_secret for token exchange)
|
|
|
|
|
let row: Option<(bool, String, String, String, String)> = sqlx::query_as(
|
|
|
|
|
let row: Option<(bool, String, String, String, String)> = match sqlx::query_as(
|
|
|
|
|
"SELECT enabled, tenant_id, client_id, client_secret, redirect_uri FROM azure_sso_config WHERE id = 1",
|
|
|
|
|
)
|
|
|
|
|
.fetch_optional(&state.db)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
tracing::error!(error = %e, "Failed to load azure_sso_config");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
{
|
|
|
|
|
Ok(r) => r,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to load azure_sso_config");
|
|
|
|
|
return Err(error_redirect("internal_error", "Database error"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let (_enabled, tenant_id, client_id, client_secret, redirect_uri) = match row {
|
|
|
|
|
Some(r) => r,
|
|
|
|
|
None => {
|
|
|
|
|
return Err((
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(
|
|
|
|
|
json!({ "error": { "code": "internal_error", "message": "Azure SSO not configured" } }),
|
|
|
|
|
),
|
|
|
|
|
));
|
|
|
|
|
return Err(error_redirect("internal_error", "Azure SSO not configured"));
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
@ -232,16 +246,16 @@ async fn azure_callback(
|
|
|
|
|
tenant_id
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
let client = reqwest::Client::builder()
|
|
|
|
|
let client = match reqwest::Client::builder()
|
|
|
|
|
.timeout(std::time::Duration::from_secs(10))
|
|
|
|
|
.build()
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
{
|
|
|
|
|
Ok(c) => c,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to build HTTP client");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
return Err(error_redirect("internal_error", "HTTP client error"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let params = [
|
|
|
|
|
("grant_type", "authorization_code".to_string()),
|
|
|
|
|
@ -254,57 +268,47 @@ async fn azure_callback(
|
|
|
|
|
|
|
|
|
|
let form_params: Vec<(&str, String)> = params.to_vec();
|
|
|
|
|
|
|
|
|
|
let token_resp = client
|
|
|
|
|
let token_resp = match client
|
|
|
|
|
.post(&token_url)
|
|
|
|
|
.form(&form_params)
|
|
|
|
|
.send()
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
{
|
|
|
|
|
Ok(r) => r,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Token exchange request failed");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::BAD_GATEWAY,
|
|
|
|
|
Json(json!({ "error": { "code": "sso_error", "message": format!("Token exchange failed: {}", e) } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
return Err(error_redirect("sso_error", &format!("Token exchange failed: {}", e)));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if !token_resp.status().is_success() {
|
|
|
|
|
let status = token_resp.status();
|
|
|
|
|
let body = token_resp.text().await.unwrap_or_default();
|
|
|
|
|
tracing::error!(status = %status, body = %body, "Token exchange failed");
|
|
|
|
|
return Err((
|
|
|
|
|
StatusCode::BAD_GATEWAY,
|
|
|
|
|
Json(
|
|
|
|
|
json!({ "error": { "code": "sso_error", "message": format!("Token exchange failed: HTTP {}", status) } }),
|
|
|
|
|
),
|
|
|
|
|
));
|
|
|
|
|
return Err(error_redirect("sso_error", &format!("Token exchange failed: HTTP {}", status)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let token_data: TokenResponse = token_resp
|
|
|
|
|
.json()
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
let token_data: TokenResponse = match token_resp.json().await {
|
|
|
|
|
Ok(d) => d,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to parse token response");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Failed to parse token response" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
return Err(error_redirect("internal_error", "Failed to parse token response"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Decode id_token JWT (without verification — trust HTTPS channel)
|
|
|
|
|
let id_token = token_data.id_token.ok_or_else(|| {
|
|
|
|
|
(
|
|
|
|
|
StatusCode::BAD_GATEWAY,
|
|
|
|
|
Json(json!({ "error": { "code": "sso_error", "message": "No id_token in response" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
// Verify id_token JWT signature using Azure AD JWKS and validate claims
|
|
|
|
|
let id_token = match token_data.id_token {
|
|
|
|
|
Some(t) => t,
|
|
|
|
|
None => return Err(error_redirect("sso_error", "No id_token in response")),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let claims = decode_jwt_payload(&id_token).map_err(|e| {
|
|
|
|
|
tracing::error!(error = %e, "Failed to decode id_token");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Failed to decode id_token" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
let claims = match verify_id_token(&id_token, &tenant_id, &client_id, &state.jwks_cache).await {
|
|
|
|
|
Ok(c) => c,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to verify id_token");
|
|
|
|
|
return Err(error_redirect("internal_error", "Failed to verify id_token"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let email = claims.email.unwrap_or_default();
|
|
|
|
|
let name = claims.name.unwrap_or_default();
|
|
|
|
|
@ -312,43 +316,33 @@ async fn azure_callback(
|
|
|
|
|
let preferred_username = claims.preferred_username.unwrap_or_else(|| email.clone());
|
|
|
|
|
|
|
|
|
|
if email.is_empty() || oid.is_empty() {
|
|
|
|
|
return Err((
|
|
|
|
|
StatusCode::BAD_GATEWAY,
|
|
|
|
|
Json(
|
|
|
|
|
json!({ "error": { "code": "sso_error", "message": "Missing email or oid in id_token" } }),
|
|
|
|
|
),
|
|
|
|
|
));
|
|
|
|
|
return Err(error_redirect("sso_error", "Missing email or oid in id_token"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Look up or create user
|
|
|
|
|
let user_opt: Option<DbUserForSso> = sqlx::query_as(
|
|
|
|
|
let user_opt: Option<DbUserForSso> = match sqlx::query_as(
|
|
|
|
|
r#"SELECT id, username, display_name, role, is_active, mfa_enabled
|
|
|
|
|
FROM users WHERE email = $1 AND auth_provider = 'azure_sso'"#,
|
|
|
|
|
)
|
|
|
|
|
.bind(&email)
|
|
|
|
|
.fetch_optional(&state.db)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
tracing::error!(error = %e, "Failed to look up SSO user");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
{
|
|
|
|
|
Ok(o) => o,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to look up SSO user");
|
|
|
|
|
return Err(error_redirect("internal_error", "Database error"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let user = match user_opt {
|
|
|
|
|
Some(u) if !u.is_active => {
|
|
|
|
|
return Err((
|
|
|
|
|
StatusCode::FORBIDDEN,
|
|
|
|
|
Json(
|
|
|
|
|
json!({ "error": { "code": "account_disabled", "message": "Account is disabled" } }),
|
|
|
|
|
),
|
|
|
|
|
));
|
|
|
|
|
return Err(error_redirect("account_disabled", "Account is disabled"));
|
|
|
|
|
},
|
|
|
|
|
Some(u) => u,
|
|
|
|
|
None => {
|
|
|
|
|
// Auto-create user with role=operator, auth_provider=azure_sso
|
|
|
|
|
let id: Uuid = sqlx::query_scalar(
|
|
|
|
|
let id: Uuid = match sqlx::query_scalar(
|
|
|
|
|
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid)
|
|
|
|
|
VALUES ($1, $2, $3, 'operator', 'azure_sso', $4)
|
|
|
|
|
RETURNING id"#,
|
|
|
|
|
@ -359,13 +353,13 @@ async fn azure_callback(
|
|
|
|
|
.bind(&oid)
|
|
|
|
|
.fetch_one(&state.db)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
tracing::error!(error = %e, "Failed to create SSO user");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Failed to create user" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
{
|
|
|
|
|
Ok(id) => id,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to create SSO user");
|
|
|
|
|
return Err(error_redirect("internal_error", "Failed to create user"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
log_event(
|
|
|
|
|
&state.db,
|
|
|
|
|
@ -392,47 +386,41 @@ async fn azure_callback(
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Update last_login_at and azure_oid
|
|
|
|
|
sqlx::query(
|
|
|
|
|
if let Err(e) = sqlx::query(
|
|
|
|
|
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1) WHERE id = $2",
|
|
|
|
|
)
|
|
|
|
|
.bind(&oid)
|
|
|
|
|
.bind(user.id)
|
|
|
|
|
.execute(&state.db)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
{
|
|
|
|
|
tracing::error!(error = %e, "Failed to update last_login_at");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
return Err(error_redirect("internal_error", "Database error"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Issue JWT access token + refresh token
|
|
|
|
|
let access_ttl = state.config.security.jwt_access_ttl_secs as i64;
|
|
|
|
|
let access_token = issue_access_token(
|
|
|
|
|
let access_token = match issue_access_token(
|
|
|
|
|
user.id,
|
|
|
|
|
&user.username,
|
|
|
|
|
&user.role,
|
|
|
|
|
access_ttl,
|
|
|
|
|
&state.signing_key_pem,
|
|
|
|
|
)
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
tracing::error!(error = %e, "Failed to issue access token");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Token issuance failed" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
) {
|
|
|
|
|
Ok(t) => t,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to issue access token");
|
|
|
|
|
return Err(error_redirect("internal_error", "Token issuance failed"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let raw_refresh = refresh::issue(&state.db, user.id, None, None)
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| {
|
|
|
|
|
let raw_refresh = match refresh::issue(&state.db, user.id, None, None).await {
|
|
|
|
|
Ok(r) => r,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!(error = %e, "Failed to issue refresh token");
|
|
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
Json(json!({ "error": { "code": "internal_error", "message": "Refresh token issuance failed" } })),
|
|
|
|
|
)
|
|
|
|
|
})?;
|
|
|
|
|
return Err(error_redirect("internal_error", "Refresh token issuance failed"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
log_event(
|
|
|
|
|
&state.db,
|
|
|
|
|
@ -447,42 +435,145 @@ async fn azure_callback(
|
|
|
|
|
)
|
|
|
|
|
.await;
|
|
|
|
|
|
|
|
|
|
Ok(Json(json!({
|
|
|
|
|
"access_token": access_token,
|
|
|
|
|
"refresh_token": raw_refresh.0,
|
|
|
|
|
"token_type": "Bearer",
|
|
|
|
|
"expires_in": access_ttl,
|
|
|
|
|
"user": {
|
|
|
|
|
"id": user.id.to_string(),
|
|
|
|
|
"username": user.username,
|
|
|
|
|
"display_name": user.display_name,
|
|
|
|
|
"role": user.role,
|
|
|
|
|
"mfa_enabled": user.mfa_enabled,
|
|
|
|
|
// Build user JSON for query parameter
|
|
|
|
|
let user_json = json!({
|
|
|
|
|
"id": user.id.to_string(),
|
|
|
|
|
"username": user.username,
|
|
|
|
|
"display_name": user.display_name,
|
|
|
|
|
"role": user.role,
|
|
|
|
|
"mfa_enabled": user.mfa_enabled,
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// Redirect to frontend SPA with tokens as query parameters
|
|
|
|
|
let redirect_url = format!(
|
|
|
|
|
"{}?access_token={}&refresh_token={}&token_type=Bearer&expires_in={}&user={}",
|
|
|
|
|
callback_url,
|
|
|
|
|
urlencoding::encode(&access_token),
|
|
|
|
|
urlencoding::encode(&raw_refresh.0),
|
|
|
|
|
access_ttl,
|
|
|
|
|
urlencoding::encode(&user_json.to_string()),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
Ok(Redirect::to(&redirect_url))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============================================================
|
|
|
|
|
// JWT Verification Helpers
|
|
|
|
|
// ============================================================
|
|
|
|
|
|
|
|
|
|
/// Verify the id_token JWT signature using Azure AD JWKS and validate standard claims.
|
|
|
|
|
///
|
|
|
|
|
/// Steps:
|
|
|
|
|
/// 1. Decode JWT header to extract `kid` (key ID)
|
|
|
|
|
/// 2. Fetch JWKS from Azure AD if cache is empty or expired (1-hour TTL)
|
|
|
|
|
/// 3. Find the matching JWK by `kid`
|
|
|
|
|
/// 4. Construct RSA public key from JWK modulus (`n`) and exponent (`e`)
|
|
|
|
|
/// 5. Validate issuer, audience, and expiry via `jsonwebtoken::decode`
|
|
|
|
|
async fn verify_id_token(
|
|
|
|
|
token: &str,
|
|
|
|
|
tenant_id: &str,
|
|
|
|
|
client_id: &str,
|
|
|
|
|
jwks_cache: &Arc<Mutex<JwksCache>>,
|
|
|
|
|
) -> Result<IdTokenClaims, String> {
|
|
|
|
|
// 1. Decode JWT header to get the kid
|
|
|
|
|
let header = decode_header(token)
|
|
|
|
|
.map_err(|e| format!("Failed to decode JWT header: {}", e))?;
|
|
|
|
|
|
|
|
|
|
let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
|
|
|
|
|
|
|
|
|
|
// 2. Check JWKS cache — fetch if expired or missing
|
|
|
|
|
let jwks = {
|
|
|
|
|
let cache = jwks_cache.lock().await;
|
|
|
|
|
let needs_fetch = match (&cache.keys, &cache.fetched_at) {
|
|
|
|
|
(None, _) => true,
|
|
|
|
|
(Some(_), None) => true,
|
|
|
|
|
(Some(_), Some(fetched)) => {
|
|
|
|
|
let elapsed = Utc::now().signed_duration_since(*fetched);
|
|
|
|
|
elapsed.num_seconds() > JWKS_CACHE_TTL_SECS
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if needs_fetch {
|
|
|
|
|
// Drop lock before making async HTTP request
|
|
|
|
|
drop(cache);
|
|
|
|
|
|
|
|
|
|
let jwks_value = fetch_jwks(tenant_id).await?;
|
|
|
|
|
|
|
|
|
|
let mut cache = jwks_cache.lock().await;
|
|
|
|
|
cache.keys = Some(jwks_value);
|
|
|
|
|
cache.fetched_at = Some(Utc::now());
|
|
|
|
|
cache.keys.clone().unwrap()
|
|
|
|
|
} else {
|
|
|
|
|
cache.keys.clone().unwrap()
|
|
|
|
|
}
|
|
|
|
|
})))
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// 3. Find the matching JWK by kid
|
|
|
|
|
let keys_array = jwks
|
|
|
|
|
.get("keys")
|
|
|
|
|
.ok_or("JWKS response missing 'keys' array")?
|
|
|
|
|
.as_array()
|
|
|
|
|
.ok_or("JWKS 'keys' is not an array")?;
|
|
|
|
|
|
|
|
|
|
let jwk = keys_array
|
|
|
|
|
.iter()
|
|
|
|
|
.find(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()))
|
|
|
|
|
.ok_or_else(|| format!("No matching JWK found for kid: {}", kid))?;
|
|
|
|
|
|
|
|
|
|
// 4. Construct RSA public key from JWK modulus (n) and exponent (e)
|
|
|
|
|
let n = jwk
|
|
|
|
|
.get("n")
|
|
|
|
|
.and_then(|v| v.as_str())
|
|
|
|
|
.ok_or("JWK missing 'n' (modulus) field")?;
|
|
|
|
|
let e = jwk
|
|
|
|
|
.get("e")
|
|
|
|
|
.and_then(|v| v.as_str())
|
|
|
|
|
.ok_or("JWK missing 'e' (exponent) field")?;
|
|
|
|
|
|
|
|
|
|
let decoding_key = DecodingKey::from_rsa_components(n, e)
|
|
|
|
|
.map_err(|e| format!("Failed to construct RSA decoding key: {}", e))?;
|
|
|
|
|
|
|
|
|
|
// 5. Configure validation rules
|
|
|
|
|
let mut validation = Validation::new(Algorithm::RS256);
|
|
|
|
|
validation.iss = Some(HashSet::from([format!(
|
|
|
|
|
"https://login.microsoftonline.com/{}/v2.0",
|
|
|
|
|
tenant_id
|
|
|
|
|
)]));
|
|
|
|
|
validation.aud = Some(HashSet::from([client_id.to_string()]));
|
|
|
|
|
validation.leeway = 60; // 60 seconds clock skew tolerance
|
|
|
|
|
|
|
|
|
|
// 6. Decode and verify the JWT
|
|
|
|
|
let token_data = decode::<IdTokenClaims>(token, &decoding_key, &validation)
|
|
|
|
|
.map_err(|e| format!("JWT signature verification failed: {}", e))?;
|
|
|
|
|
|
|
|
|
|
Ok(token_data.claims)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============================================================
|
|
|
|
|
// Helpers
|
|
|
|
|
// ============================================================
|
|
|
|
|
/// Fetch the JWKS from the Azure AD discovery endpoint.
|
|
|
|
|
async fn fetch_jwks(tenant_id: &str) -> Result<serde_json::Value, String> {
|
|
|
|
|
let jwks_url = format!(
|
|
|
|
|
"https://login.microsoftonline.com/{}/discovery/v2.0/keys",
|
|
|
|
|
tenant_id
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
/// Decode JWT payload without verification (trust HTTPS channel from Azure AD).
|
|
|
|
|
fn decode_jwt_payload(token: &str) -> Result<IdTokenClaims, String> {
|
|
|
|
|
let parts: Vec<&str> = token.split('.').collect();
|
|
|
|
|
if parts.len() != 3 {
|
|
|
|
|
return Err("Invalid JWT format".to_string());
|
|
|
|
|
let client = reqwest::Client::builder()
|
|
|
|
|
.timeout(std::time::Duration::from_secs(10))
|
|
|
|
|
.build()
|
|
|
|
|
.map_err(|e| format!("Failed to build HTTP client for JWKS fetch: {}", e))?;
|
|
|
|
|
|
|
|
|
|
let resp = client
|
|
|
|
|
.get(&jwks_url)
|
|
|
|
|
.send()
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| format!("JWKS fetch request failed: {}", e))?;
|
|
|
|
|
|
|
|
|
|
if !resp.status().is_success() {
|
|
|
|
|
let status = resp.status();
|
|
|
|
|
let body = resp.text().await.unwrap_or_default();
|
|
|
|
|
return Err(format!("JWKS fetch failed: HTTP {} — {}", status, body));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let payload_b64 = parts[1];
|
|
|
|
|
// Add padding if needed
|
|
|
|
|
let mut payload_b64_padded = payload_b64.to_string();
|
|
|
|
|
while payload_b64_padded.len() % 4 != 0 {
|
|
|
|
|
payload_b64_padded.push('=');
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let payload_bytes = base64::engine::general_purpose::STANDARD
|
|
|
|
|
.decode(&payload_b64_padded)
|
|
|
|
|
.map_err(|e| format!("Base64 decode error: {}", e))?;
|
|
|
|
|
|
|
|
|
|
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON parse error: {}", e))
|
|
|
|
|
resp.json::<serde_json::Value>()
|
|
|
|
|
.await
|
|
|
|
|
.map_err(|e| format!("Failed to parse JWKS response: {}", e))
|
|
|
|
|
}
|
|
|
|
|
|