Private
Public Access
1
0

feat: add bump-version.sh script for version management

Automates version bumps across all version source files:
- Cargo.toml (PRIMARY - workspace.package.version)
- debian/changelog (prepend new entry)
- debian/control (update Version field)
- scripts/build-package.sh (update VERSION variable)
- frontend/package.json (update version field)
- Stale references check after bump

Usage: ./scripts/bump-version.sh <new_version> <old_version>
This commit is contained in:
2026-05-28 10:52:16 -05:00
commit 124b5b0e3b
153 changed files with 41878 additions and 0 deletions

46
crates/pm-web/Cargo.toml Normal file
View File

@ -0,0 +1,46 @@
[package]
name = "pm-web"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[[bin]]
name = "pm-web"
path = "src/main.rs"
[dependencies]
pm-ca = { path = "../pm-ca" }
pm-core = { path = "../pm-core" }
pm-auth = { path = "../pm-auth" }
pm-reports = { path = "../pm-reports" }
tokio = { workspace = true }
axum = { workspace = true }
axum-server = { workspace = true }
rustls = { workspace = true }
axum-extra = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
sqlx = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
ulid = { workspace = true }
chrono = { workspace = true }
ipnet = { workspace = true }
dashmap = { version = "6" }
tower_governor = { workspace = true }
governor = { workspace = true }
reqwest = { workspace = true }
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
rand = { workspace = true }
hex = "0.4"
base64 = { workspace = true }
sha2 = { workspace = true }
jsonwebtoken = { workspace = true }
url = { workspace = true }
urlencoding = "2"

353
crates/pm-web/src/main.rs Normal file
View File

@ -0,0 +1,353 @@
//! pm-web — Linux Patch Manager web server.
mod routes;
use axum::{extract::State, http::StatusCode, middleware, response::Json, routing::get, Router};
use axum_server::tls_rustls::RustlsConfig;
use dashmap::DashMap;
use pm_auth::{
jwt,
rbac::{require_auth, AuthConfig},
};
use pm_core::{
config::AppConfig, db, logging, models::PkiBundle, request_id::request_id_middleware,
};
use routes::sso::{OidcCache, SsoSession};
use routes::ws::WsTicket;
use serde_json::{json, Value};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::sync::Mutex;
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
};
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
/// Shared application state threaded through Axum.
#[derive(Clone)]
pub struct AppState {
pub db: sqlx::PgPool,
pub config: Arc<AppConfig>,
pub signing_key_pem: String,
pub auth_config: Arc<AuthConfig>,
/// In-memory store for single-use WebSocket authentication tickets.
pub ws_tickets: Arc<DashMap<String, WsTicket>>,
/// In-memory store for SSO PKCE sessions (state → code_verifier).
pub sso_sessions: Arc<DashMap<String, SsoSession>>,
/// Cached OIDC discovery document and JWKS for SSO id_token verification.
pub oidc_cache: Arc<Mutex<OidcCache>>,
/// Internal certificate authority for mTLS client cert issuance.
pub ca: Arc<pm_ca::CertAuthority>,
/// Short-lived cache for approved enrollment PKI bundles.
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Install the default crypto provider for rustls (required since 0.23)
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install rustls crypto provider");
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
eprintln!("Config file not found or invalid, using defaults");
AppConfig::default()
});
logging::init(&config.logging);
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
"patch-manager-web starting"
);
let signing_key_pem = jwt::load_signing_key(&config.security.jwt_signing_key_path)
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "JWT signing key not found (dev mode)");
String::new()
});
let verify_key_pem =
jwt::load_verify_key(&config.security.jwt_verify_key_path).unwrap_or_else(|e| {
tracing::warn!(error = %e, "JWT verify key not found (dev mode)");
String::new()
});
let auth_config = Arc::new(AuthConfig::new(
verify_key_pem,
&config.security.ip_whitelist,
));
let pool = db::init_pool(&config.database).await?;
db::run_migrations(&pool).await?;
// Initialise the internal CA using the configured certificate paths.
// The CA certificate and key must exist at the configured locations and be
// unencrypted PEM. If absent, a new CA is generated in that directory.
let ca_base = std::path::Path::new(&config.security.ca_cert_path)
.parent()
.expect("CA certificate path must have a parent directory");
let ca = pm_ca::CertAuthority::init(ca_base, &pool)
.await
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "CA init failed (dev mode)");
panic!("CA initialization failed: {}", e);
});
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
// Background task: purge expired WS tickets every 30 seconds.
{
let tickets = ws_tickets.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
let now = chrono::Utc::now();
let before = tickets.len();
tickets.retain(|_, v| v.expires_at > now);
let removed = before.saturating_sub(tickets.len());
if removed > 0 {
tracing::debug!(removed, "Purged expired WS tickets");
}
}
});
}
// Background task: purge expired SSO sessions every 60 seconds (sessions older than 10 minutes).
{
let sessions = sso_sessions.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
let now = chrono::Utc::now();
let cutoff = now - chrono::Duration::minutes(10);
let before = sessions.len();
sessions.retain(|_, v| v.created_at > cutoff);
let removed = before.saturating_sub(sessions.len());
if removed > 0 {
tracing::debug!(removed, "Purged expired SSO sessions");
}
}
});
}
// Background task: purge approved enrollment PKI bundles every 10 minutes.
{
let approved = approved_enrollments.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(600));
loop {
interval.tick().await;
approved.clear();
}
});
}
let state = AppState {
db: pool,
config: Arc::new(config.clone()),
signing_key_pem,
auth_config,
ws_tickets,
sso_sessions,
ca: Arc::new(ca),
approved_enrollments,
oidc_cache,
};
let app = build_router(state);
let addr: SocketAddr = format!("{}:{}", config.server.host, config.server.port)
.parse()
.expect("Invalid bind address");
// Try to load TLS certificate and key; fall back to plain HTTP if missing.
let tls_cert = std::path::Path::new(&config.security.web_tls_cert_path);
let tls_key = std::path::Path::new(&config.security.web_tls_key_path);
if tls_cert.exists() && tls_key.exists() {
let tls_config = RustlsConfig::from_pem_file(
&config.security.web_tls_cert_path,
&config.security.web_tls_key_path,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load TLS certificates");
e
})?;
tracing::info!(%addr, "Listening (HTTPS)");
axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
} else {
tracing::warn!(
cert_path = %config.security.web_tls_cert_path,
key_path = %config.security.web_tls_key_path,
"TLS certificates not found — falling back to plain HTTP. \
This is insecure for production!"
);
tracing::info!(%addr, "Listening (HTTP — no TLS)");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
}
Ok(())
}
/// Construct the full Axum router.
pub fn build_router(state: AppState) -> Router {
let static_dir = state.config.server.static_dir.clone();
let auth_config = state.auth_config.clone();
let rl = &state.config.rate_limit;
// Enrollment rate limiting: strict (5 req/min per IP, burst 3)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 12_000ms = ~5/min sustained
let enrollment_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(12_000)
.burst_size(rl.enrollment_burst)
.finish()
.expect("Invalid enrollment governor config"),
);
// Auth rate limiting: moderate (20 req/min per IP, burst 10)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 3_000ms = ~20/min sustained
let auth_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(3_000)
.burst_size(rl.auth_burst)
.finish()
.expect("Invalid auth governor config"),
);
// API rate limiting: normal (120 req/min per IP, burst 30)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 500ms = ~120/min sustained
let api_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(500)
.burst_size(rl.api_burst)
.finish()
.expect("Invalid API governor config"),
);
// Enrollment routes with strict per-IP rate limiting
let enrollment_router =
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
// Public auth routes with moderate per-IP rate limiting
let auth_public_router =
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
// SSO routes with moderate per-IP rate limiting
let sso_public_router =
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
let sso_azure_router =
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
// All protected API routes — require valid JWT, with normal per-IP rate limiting
let protected_api = Router::new()
// Auth: MFA setup/verify
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
.nest("/auth", routes::auth::protected_router())
// Hosts
.nest("/hosts", routes::hosts::router())
// Host-scoped certificate endpoints (merged separately to avoid conflict)
.nest("/hosts", routes::ca::host_cert_router())
// Groups
.nest("/groups", routes::groups::router())
// Users
.nest("/users", routes::users::router())
// Discovery
.nest("/discovery", routes::discovery::router())
// Fleet status
.nest("/status", routes::status::router())
// Patch jobs
.nest("/jobs", routes::jobs::router())
// Maintenance windows (nested under hosts path param)
.nest(
"/hosts/{host_id}/maintenance-windows",
routes::maintenance_windows::router(),
)
// Maintenance windows — bulk list-all endpoint
.nest(
"/maintenance-windows",
routes::maintenance_windows::all_windows_router(),
)
// CA root certificate download
.nest("/ca", routes::ca::ca_router())
// Certificate list / renew / revoke
.nest("/certificates", routes::ca::certs_router())
// WS ticket issuance (JWT-protected — ticket returned to browser, then used for WS upgrade)
.merge(routes::ws::ticket_router())
// Reports
.nest("/reports", routes::reports::router())
.nest(
"/hosts/{host_id}/health-checks",
routes::health_checks::router(),
)
// Settings (admin-only)
.nest("/settings", routes::settings::router())
// Admin enrollment routes (JWT protected, Admin role enforced)
.nest("/admin", routes::enrollment::admin_router())
// Apply rate limiting then auth middleware
.layer(GovernorLayer::new(api_governor))
.route_layer(middleware::from_fn(move |req, next| {
let auth_config = auth_config.clone();
require_auth(auth_config, req, next)
}));
Router::new()
.route("/status/health", get(health_handler))
// Public auth routes (rate-limited, no JWT)
.nest("/api/v1/auth", auth_public_router)
// Public enrollment endpoints (rate-limited, no JWT)
.nest("/api/v1", enrollment_router)
// Public SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/sso", sso_public_router)
// Public Azure SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/azure", sso_azure_router)
// Protected API routes (JWT required, rate-limited)
.nest("/api/v1", protected_api)
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
.merge(routes::ws::ws_router())
// Serve React SPA
.fallback_service(
ServeDir::new(&static_dir)
.append_index_html_on_directories(true)
.fallback(ServeFile::new(format!("{}/index.html", static_dir))),
)
.layer(middleware::from_fn(request_id_middleware))
.layer(TraceLayer::new_for_http())
.with_state(state)
}
async fn health_handler(State(state): State<AppState>) -> Result<Json<Value>, StatusCode> {
let db_ok = sqlx::query("SELECT 1").execute(&state.db).await.is_ok();
let status = if db_ok { "healthy" } else { "degraded" };
let body = json!({ "service": "patch-manager-web", "version": env!("CARGO_PKG_VERSION"), "status": status, "database": if db_ok { "ok" } else { "error" } });
if db_ok {
Ok(Json(body))
} else {
Err(StatusCode::SERVICE_UNAVAILABLE)
}
}

434
crates/pm-web/src/routes/auth.rs Executable file
View File

@ -0,0 +1,434 @@
//! Authentication route handlers.
//!
//! Public routes (no auth required):
//! POST /api/v1/auth/login
//! POST /api/v1/auth/refresh
//! POST /api/v1/auth/logout
//!
//! Protected routes (JWT required):
//! GET /api/v1/auth/mfa/setup
//! POST /api/v1/auth/mfa/verify
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::Json,
routing::delete,
routing::{get, post},
Router,
};
use pm_auth::{
hash_password, mfa_totp,
rbac::AuthUser,
session::{self, LoginRequest, LoginResponse},
validate_password_strength, verify_password,
};
use serde::Deserialize;
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
// ============================================================
// Public router — no authentication required
// ============================================================
pub fn public_router() -> Router<AppState> {
Router::new()
.route("/login", post(login_handler))
.route("/refresh", post(refresh_handler))
.route("/logout", post(logout_handler))
.route(
"/force-change-password",
post(force_change_password_handler),
)
}
// ============================================================
// Protected router — requires valid JWT (applied by caller)
// ============================================================
pub fn protected_router() -> Router<AppState> {
Router::new()
.route("/mfa/setup", get(mfa_setup_handler))
.route("/mfa/verify", post(mfa_verify_handler))
.route("/mfa", delete(disable_mfa))
}
// ============================================================
// Helpers
// ============================================================
fn user_agent(headers: &HeaderMap) -> Option<String> {
headers
.get("user-agent")
.and_then(|v| v.to_str().ok())
.map(str::to_string)
}
fn remote_ip(headers: &HeaderMap) -> Option<String> {
headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.map(|s| s.split(',').next().unwrap_or("").trim().to_string())
}
// ============================================================
// POST /api/v1/auth/login
// ============================================================
async fn login_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<LoginRequest>,
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
let ip = remote_ip(&headers);
let ua = user_agent(&headers);
session::login(
&state.db,
&req,
&state.signing_key_pem,
state.config.security.jwt_access_ttl_secs as i64,
ua.as_deref(),
ip.as_deref(),
)
.await
.map(Json)
.map_err(|e| {
use pm_auth::session::SessionError;
let (status, code, message) = match e {
SessionError::InvalidCredentials | SessionError::InvalidMfaCode => (
StatusCode::UNAUTHORIZED,
"invalid_credentials",
"Invalid username or password",
),
SessionError::MfaRequired => (
StatusCode::UNAUTHORIZED,
"mfa_required",
"MFA code required",
),
SessionError::AccountDisabled => (
StatusCode::FORBIDDEN,
"account_disabled",
"Account is disabled",
),
SessionError::PasswordResetRequired => (
StatusCode::FORBIDDEN,
"password_reset_required",
"Password reset is required before login",
),
SessionError::AccountLocked => (
StatusCode::LOCKED,
"account_locked",
"Account is locked due to too many failed login attempts",
),
_ => {
tracing::error!(error = %e, "Login error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"An error occurred",
)
},
};
(
status,
Json(json!({ "error": { "code": code, "message": message } })),
)
})
}
// ============================================================
// POST /api/v1/auth/refresh
// ============================================================
#[derive(Debug, Deserialize)]
struct RefreshRequest {
refresh_token: String,
}
async fn refresh_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<RefreshRequest>,
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
let ip = remote_ip(&headers);
let ua = user_agent(&headers);
session::refresh_session(
&state.db,
&req.refresh_token,
&state.signing_key_pem,
state.config.security.jwt_access_ttl_secs as i64,
ua.as_deref(),
ip.as_deref(),
)
.await
.map(Json)
.map_err(|e| {
use pm_auth::session::SessionError;
let (status, code, msg) = match e {
SessionError::Refresh(_) => (
StatusCode::UNAUTHORIZED,
"invalid_refresh_token",
"Refresh token is invalid or expired",
),
SessionError::AccountDisabled => (
StatusCode::FORBIDDEN,
"account_disabled",
"Account is disabled",
),
_ => {
tracing::error!(error = %e, "Refresh error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"An error occurred",
)
},
};
(
status,
Json(json!({ "error": { "code": code, "message": msg } })),
)
})
}
// ============================================================
// POST /api/v1/auth/logout
// ============================================================
#[derive(Debug, Deserialize)]
struct LogoutRequest {
refresh_token: String,
}
async fn logout_handler(
State(state): State<AppState>,
Json(req): Json<LogoutRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
session::logout(&state.db, &req.refresh_token)
.await
.map(|_| Json(json!({ "message": "Logged out successfully" })))
.map_err(|e| {
tracing::error!(error = %e, "Logout error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "An error occurred" } })),
)
})
}
// ============================================================
// GET /api/v1/auth/mfa/setup (JWT required — via middleware)
// ============================================================
// ============================================================
// POST /api/v1/auth/force-change-password (PUBLIC — no JWT)
// ============================================================
#[derive(Debug, Deserialize)]
struct ForceChangePasswordRequest {
username: String,
current_password: String,
new_password: String,
}
async fn force_change_password_handler(
State(state): State<AppState>,
Json(req): Json<ForceChangePasswordRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Validate new password strength
if let Err(msg) = validate_password_strength(&req.new_password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
// Look up user by username
let row: Option<(Uuid, Option<String>, bool)> = sqlx::query_as(
"SELECT id, password_hash, force_password_reset FROM users WHERE username = $1 AND auth_provider = 'local'",
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to fetch user");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let (user_id, hash_opt, _force_reset) = match row {
Some(r) => r,
None => {
return Err((
StatusCode::UNAUTHORIZED,
Json(
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
),
));
},
};
// Verify current password
let hash_str = hash_opt.as_deref().unwrap_or("");
let valid = verify_password(&req.current_password, hash_str).unwrap_or(false);
if !valid {
return Err((
StatusCode::UNAUTHORIZED,
Json(
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
),
));
}
// Hash and update password, clear force_password_reset
let new_hash = hash_password(&req.new_password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
sqlx::query(
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, failed_login_attempts = 0, locked_until = NULL, updated_at = NOW() WHERE id = $2",
)
.bind(&new_hash)
.bind(user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to update password");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
)
})?;
tracing::info!(user_id = %user_id, username = %req.username, "Password changed via force-change-password");
Ok(Json(json!({ "message": "Password changed successfully" })))
}
async fn mfa_setup_handler(
auth_user: AuthUser,
) -> Result<Json<mfa_totp::TotpSetup>, (StatusCode, Json<Value>)> {
mfa_totp::generate_setup(&auth_user.username)
.map(Json)
.map_err(|e| {
tracing::error!(error = %e, "TOTP setup error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})
}
// ============================================================
// POST /api/v1/auth/mfa/verify (JWT required — via middleware)
// ============================================================
#[derive(Debug, Deserialize)]
struct MfaVerifyRequest {
secret_base32: String,
code: String,
}
async fn mfa_verify_handler(
State(state): State<AppState>,
auth_user: AuthUser,
Json(req): Json<MfaVerifyRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let valid =
mfa_totp::verify_code(&auth_user.username, &req.secret_base32, &req.code).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
if !valid {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "invalid_code", "message": "Invalid TOTP code" } })),
));
}
sqlx::query("UPDATE users SET totp_secret = $1, mfa_enabled = TRUE WHERE id = $2")
.bind(&req.secret_base32)
.bind(auth_user.user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to save TOTP secret");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to enable MFA" } })),
)
})?;
tracing::info!(user_id = %auth_user.user_id, "MFA enabled for user");
Ok(Json(json!({ "message": "MFA enabled successfully" })))
}
// ============================================================
// DELETE /api/v1/auth/mfa (JWT required — disable own MFA)
// ============================================================
#[derive(Debug, Deserialize)]
struct DisableMfaRequest {
password: String,
}
async fn disable_mfa(
State(state): State<AppState>,
auth_user: AuthUser,
Json(req): Json<DisableMfaRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Verify current password to confirm identity
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
.bind(auth_user.user_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to fetch password hash");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?
.flatten();
let hash_str = hash.unwrap_or_default();
let valid = verify_password(&req.password, &hash_str).unwrap_or(false);
if !valid {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
),
));
}
sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE WHERE id = $1")
.bind(auth_user.user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to disable MFA");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
)
})?;
tracing::info!(user_id = %auth_user.user_id, "MFA disabled for user");
Ok(Json(json!({ "message": "MFA disabled successfully" })))
}

516
crates/pm-web/src/routes/ca.rs Executable file
View File

@ -0,0 +1,516 @@
//! CA / certificate management routes.
//!
//! ca_router() → mounted at /api/v1/ca
//! GET /root.crt download_root_ca (any authed role)
//!
//! certs_router() → mounted at /api/v1/certificates
//! GET / list_certificates (any authed role)
//! POST /:cert_id/renew renew_cert (admin only)
//! DELETE /:cert_id revoke_cert (admin only)
//!
//! host_cert_router() → merged under /api/v1/hosts
//! GET /:host_id/client.crt download_client_cert (admin only)
//! POST /:host_id/certificates issue_client_cert (admin only)
//! POST /:host_id/certificates/reissue reissue_host_cert (admin only)
use axum::{
body::Body,
extract::{Path, Query, State},
http::{header, Response, StatusCode},
response::Json,
routing::{delete, get, post},
Router,
};
use chrono::{DateTime, Utc};
use pm_auth::rbac::AuthUser;
use pm_core::audit::{log_event, AuditAction};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlx::Row;
use uuid::Uuid;
use crate::AppState;
// ── Router constructors ───────────────────────────────────────────────────────
/// Handles routes mounted at /api/v1/ca
pub fn ca_router() -> Router<AppState> {
Router::new().route("/root.crt", get(download_root_ca))
}
/// Handles routes mounted at /api/v1/certificates
pub fn certs_router() -> Router<AppState> {
Router::new()
.route("/", get(list_certificates))
.route("/{cert_id}/renew", post(renew_cert))
.route("/{cert_id}", delete(revoke_cert))
}
/// Handles cert-specific paths merged under /api/v1/hosts.
/// Only adds paths not already claimed by the hosts router.
pub fn host_cert_router() -> Router<AppState> {
Router::new()
.route("/{host_id}/client.crt", get(download_client_cert))
.route("/{host_id}/certificates", post(issue_client_cert))
.route("/{host_id}/certificates/reissue", post(reissue_host_cert))
}
// ── Shared types ──────────────────────────────────────────────────────────────
/// Row returned from the `certificates` table.
#[derive(Debug, Serialize, sqlx::FromRow)]
struct CertRow {
id: Uuid,
host_id: Option<Uuid>,
serial_number: String,
common_name: String,
/// Cast to TEXT in all queries to avoid custom-enum decode.
status: String,
issued_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
/// Query params for `list_certificates`.
#[derive(Debug, Deserialize)]
struct CertListQuery {
host_id: Option<Uuid>,
status: Option<String>,
}
/// Request body for `issue_client_cert`.
#[derive(Debug, Deserialize)]
struct IssueCertRequest {
hostname: String,
}
// ── Helper: build PEM download response ──────────────────────────────────────
fn pem_response(pem: String, filename: &str) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
let disposition = format!("attachment; filename=\"{filename}\"");
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/x-pem-file")
.header(header::CONTENT_DISPOSITION, disposition)
.body(Body::from(pem))
.map_err(|e| {
tracing::error!(error = %e, "Failed to build PEM response");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Response build error" } })),
)
})
}
// ── Helper: admin-only guard ──────────────────────────────────────────────────
fn require_write_access(user: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
if !user.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
Ok(())
}
// ── Helper: map sqlx error to 500 ─────────────────────────────────────────────
fn db_error(e: sqlx::Error) -> (StatusCode, Json<Value>) {
tracing::error!(error = %e, "Database error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
}
// ── Helper: build the full IssuedCert JSON response ──────────────────────────
fn issued_cert_json(issued: &pm_ca::IssuedCert) -> Value {
json!({
"cert_pem": issued.cert_pem,
"key_pem": issued.key_pem,
"serial_number": issued.serial_number,
"expires_at": issued.expires_at,
"server_cert_pem": issued.server_cert_pem,
"server_key_pem": issued.server_key_pem,
"server_serial_number": issued.server_serial_number,
"ca_root_pem": issued.ca_root_pem,
})
}
// ── GET /api/v1/ca/root.crt ───────────────────────────────────────────────────
/// Download the root CA certificate as a PEM file.
async fn download_root_ca(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
let pem = state.ca.root_cert_pem().to_owned();
log_event(
&state.db,
AuditAction::CertificateDownloaded,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some("root_ca"),
json!({ "operation": "download_root_ca" }),
None,
None,
)
.await;
pem_response(pem, "ca.crt")
}
// ── GET /api/v1/certificates ──────────────────────────────────────────────────
/// List certificates with optional `?host_id=` and `?status=` filters.
async fn list_certificates(
State(state): State<AppState>,
_auth: AuthUser,
Query(q): Query<CertListQuery>,
) -> Result<Json<Vec<CertRow>>, (StatusCode, Json<Value>)> {
// Use the non-macro query_as form — avoids needing DATABASE_URL at compile
// time. status is cast to TEXT so sqlx decodes it into String directly.
let rows: Vec<CertRow> = match (q.host_id, q.status.as_deref()) {
(Some(hid), Some(st)) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
WHERE host_id = $1 AND status::text = $2
ORDER BY issued_at DESC"#,
)
.bind(hid)
.bind(st)
.fetch_all(&state.db)
.await
},
(Some(hid), None) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
WHERE host_id = $1
ORDER BY issued_at DESC"#,
)
.bind(hid)
.fetch_all(&state.db)
.await
},
(None, Some(st)) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
WHERE status::text = $1
ORDER BY issued_at DESC"#,
)
.bind(st)
.fetch_all(&state.db)
.await
},
(None, None) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
ORDER BY issued_at DESC"#,
)
.fetch_all(&state.db)
.await
},
}
.map_err(db_error)?;
Ok(Json(rows))
}
// ── GET /api/v1/hosts/:host_id/client.crt ────────────────────────────────────
/// Download the most recent active client certificate PEM for a host.
async fn download_client_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
let cert_pem: Option<String> = sqlx::query_scalar(
r#"SELECT cert_pem
FROM certificates
WHERE host_id = $1
AND status = 'active'::cert_status
AND common_name NOT LIKE '%-server'
ORDER BY issued_at DESC
LIMIT 1"#,
)
.bind(host_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to fetch client cert");
db_error(e)
})?;
match cert_pem {
Some(pem) => {
log_event(
&state.db,
AuditAction::CertificateDownloaded,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&host_id.to_string()),
json!({ "operation": "download_client_cert" }),
None,
None,
)
.await;
pem_response(pem, "client.crt")
},
None => Err((
StatusCode::NOT_FOUND,
Json(json!({
"error": {
"code": "not_found",
"message": "No active certificate found for this host"
}
})),
)),
}
}
// ── POST /api/v1/hosts/:host_id/certificates ─────────────────────────────────
/// Issue a new mTLS client certificate (and server certificate) for a host.
/// **The private keys are returned only once — the caller must save them.**
async fn issue_client_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
Json(req): Json<IssueCertRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
// Look up the host's IP address from the database.
let ip_address: String = sqlx::query_scalar("SELECT host(ip_address) FROM hosts WHERE id = $1")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to fetch host IP address");
if e.to_string().contains("no rows") {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
} else {
db_error(e)
}
})?;
let issued = state
.ca
.issue_client_cert(host_id, &req.hostname, &ip_address, &state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, hostname = %req.hostname,
"Failed to issue client cert");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::CertificateIssued,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&host_id.to_string()),
json!({ "hostname": req.hostname, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
None,
None,
)
.await;
Ok(Json(issued_cert_json(&issued)))
}
// ── POST /api/v1/certificates/:cert_id/renew ─────────────────────────────────
/// Revoke the specified certificate and issue a replacement with the same CN.
async fn renew_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(cert_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
let issued = state.ca.renew_cert(cert_id, &state.db).await.map_err(|e| {
let msg = e.to_string();
tracing::error!(error = %e, %cert_id, "Failed to renew cert");
if msg.contains("not found") {
(
StatusCode::NOT_FOUND,
Json(
json!({ "error": { "code": "not_found", "message": "Certificate not found" } }),
),
)
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
)
}
})?;
log_event(
&state.db,
AuditAction::CertificateRenewed,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&cert_id.to_string()),
json!({ "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
None,
None,
)
.await;
Ok(Json(issued_cert_json(&issued)))
}
// ── POST /api/v1/hosts/:host_id/certificates/reissue ────────────────────────
/// Revoke ALL active certificates for a host and issue new ones.
/// The private keys are returned only once — the caller must save them.
async fn reissue_host_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
// Look up the host's FQDN and IP address for the new certificate CN and SANs.
let row = sqlx::query("SELECT fqdn, host(ip_address) AS ip_address FROM hosts WHERE id = $1")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to fetch host FQDN/IP");
if e.to_string().contains("no rows") {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
} else {
db_error(e)
}
})?;
let fqdn: String = row.try_get("fqdn").map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to read fqdn");
db_error(e)
})?;
let ip_address: String = row.try_get("ip_address").map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to read ip_address");
db_error(e)
})?;
// Revoke all active certificates for this host.
let revoked = sqlx::query(
"UPDATE certificates SET status = 'revoked'::cert_status, revoked_at = NOW() \
WHERE host_id = $1 AND status = 'active'::cert_status",
)
.bind(host_id)
.execute(&state.db)
.await
.map_err(db_error)?;
tracing::info!(%host_id, rows_revoked = revoked.rows_affected(), "Revoked all active certs for host");
// Issue a new certificate bundle using the host's FQDN and IP.
let issued = state
.ca
.issue_client_cert(host_id, &fqdn, &ip_address, &state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to issue new cert during reissue");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::CertificateReissued,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&host_id.to_string()),
json!({ "hostname": &fqdn, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number, "rows_revoked": revoked.rows_affected() }),
None,
None,
)
.await;
Ok(Json(issued_cert_json(&issued)))
}
// ── DELETE /api/v1/certificates/:cert_id ─────────────────────────────────────
/// Revoke a certificate by ID. Sets status to 'revoked' in the database.
async fn revoke_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(cert_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
state
.ca
.revoke_cert(cert_id, &state.db)
.await
.map_err(|e| {
let msg = e.to_string();
tracing::error!(error = %e, %cert_id, "Failed to revoke cert");
if msg.contains("not found") {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Certificate not found" } })),
)
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
)
}
})?;
tracing::info!(%cert_id, "Certificate revoked via API");
log_event(
&state.db,
AuditAction::CertificateRevoked,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&cert_id.to_string()),
json!({ "operation": "revoke" }),
None,
None,
)
.await;
Ok(Json(json!({ "revoked": true })))
}

View File

@ -0,0 +1,304 @@
//! CIDR auto-discovery routes.
//!
//! POST /api/v1/discovery/cidr — start a CIDR scan
//! GET /api/v1/discovery/:scan_id — get scan results
//! POST /api/v1/discovery/:id/register — register a discovered host
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{DiscoveryCidrRequest, DiscoveryResult, RegisterDiscoveredRequest},
};
use serde_json::{json, Value};
use std::{
net::{IpAddr, TcpStream},
time::Duration,
};
use tokio::{sync::Semaphore, task};
use uuid::Uuid;
use crate::AppState;
/// Maximum concurrent TCP probes during CIDR scan.
const MAX_CONCURRENT_PROBES: usize = 128;
/// TCP connect timeout per probe.
const PROBE_TIMEOUT_SECS: u64 = 2;
pub fn router() -> Router<AppState> {
Router::new()
.route("/cidr", post(start_cidr_scan))
.route("/{scan_id}", get(get_scan_results))
.route("/{id}/register", post(register_discovered_host))
}
// ── POST /api/v1/discovery/cidr ───────────────────────────────────────────────
async fn start_cidr_scan(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<DiscoveryCidrRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let cidr: ipnet::IpNet = req.cidr.parse().map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "bad_request", "message": "Invalid CIDR range" } })),
)
})?;
let agent_port = req.agent_port.unwrap_or(12443) as u16;
let scan_id = Uuid::new_v4();
// Clear previous results for this type of scan and start async scan
let pool = state.db.clone();
let scan_id_clone = scan_id;
let cidr_str = req.cidr.clone();
// Spawn non-blocking background scan
task::spawn(async move {
run_cidr_scan(pool, scan_id_clone, cidr, agent_port).await;
});
log_event(
&state.db,
AuditAction::DiscoveryScanStarted,
Some(auth.user_id),
Some(&auth.username),
Some("discovery"),
Some(&scan_id.to_string()),
json!({ "cidr": cidr_str }),
None,
None,
)
.await;
tracing::info!(scan_id = %scan_id, cidr = %req.cidr, "CIDR scan started");
Ok(Json(
json!({ "scan_id": scan_id, "message": "Discovery scan started", "cidr": req.cidr }),
))
}
/// Background CIDR scanner.
async fn run_cidr_scan(pool: sqlx::PgPool, scan_id: Uuid, cidr: ipnet::IpNet, port: u16) {
let semaphore = std::sync::Arc::new(Semaphore::new(MAX_CONCURRENT_PROBES));
let hosts: Vec<IpAddr> = cidr.hosts().collect();
let total = hosts.len();
tracing::info!(scan_id = %scan_id, total = total, "CIDR scan probing {} hosts", total);
let mut handles = Vec::new();
for ip in hosts {
let sem = semaphore.clone();
let pool_clone = pool.clone();
let h = task::spawn(async move {
let _permit = sem.acquire().await.ok()?;
probe_and_store(pool_clone, scan_id, ip, port).await
});
handles.push(h);
}
for h in handles {
let _ = h.await;
}
tracing::info!(scan_id = %scan_id, "CIDR scan complete");
}
/// Probe a single IP:port and store the result if the port is open.
async fn probe_and_store(pool: sqlx::PgPool, scan_id: Uuid, ip: IpAddr, port: u16) -> Option<()> {
let addr = format!("{ip}:{port}");
// TCP connect probe (blocking, run in thread pool)
// TCP connect probe (blocking, run in thread pool)
let addr_clone = addr.clone();
let open = task::spawn_blocking(move || {
TcpStream::connect_timeout(
&match addr_clone.parse() {
Ok(a) => a,
Err(_) => return false,
},
Duration::from_secs(PROBE_TIMEOUT_SECS),
)
.is_ok()
})
.await
.unwrap_or(false);
if !open {
return None;
}
// Reverse DNS lookup (best-effort)
let ip_clone = ip;
let fqdn = task::spawn_blocking(move || {
use std::net::ToSocketAddrs;
let addr = format!("{ip_clone}:{port}");
addr.to_socket_addrs()
.ok()
.and_then(|mut a| a.next())
.and_then(|_| dns_lookup_for_ip(ip_clone))
})
.await
.ok()
.flatten();
let _ = sqlx::query(
r#"INSERT INTO discovery_results (scan_id, ip_address, fqdn, agent_port)
VALUES ($1, $2::inet, $3, $4)
ON CONFLICT DO NOTHING"#,
)
.bind(scan_id)
.bind(ip.to_string())
.bind(fqdn)
.bind(port as i32)
.execute(&pool)
.await;
tracing::debug!(ip = %ip, port = port, "Discovered agent");
Some(())
}
/// Simple reverse DNS lookup.
fn dns_lookup_for_ip(ip: IpAddr) -> Option<String> {
use std::net::{SocketAddr, ToSocketAddrs};
let _addr = SocketAddr::new(ip, 0);
// Standard library doesn't have reverse lookup; use getaddrinfo via format
let host = format!("{ip}");
// Best-effort: try to resolve numeric address to hostname
(host + ":0")
.to_socket_addrs()
.ok()?
.next()
.map(|a| a.ip().to_string())
.filter(|s| s != &ip.to_string())
}
// ── GET /api/v1/discovery/:scan_id ────────────────────────────────────────────
async fn get_scan_results(
State(state): State<AppState>,
_auth: AuthUser,
Path(scan_id): Path<Uuid>,
) -> Result<Json<Vec<DiscoveryResult>>, (StatusCode, Json<Value>)> {
sqlx::query_as::<_, DiscoveryResult>(
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
agent_version, os_name, agent_port, discovered_at, registered
FROM discovery_results
WHERE scan_id = $1
ORDER BY ip_address"#,
)
.bind(scan_id)
.fetch_all(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})
}
// ── POST /api/v1/discovery/:id/register ──────────────────────────────────────
async fn register_discovered_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<RegisterDiscoveredRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Fetch discovery result
let result: Option<DiscoveryResult> = sqlx::query_as(
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
agent_version, os_name, agent_port, discovered_at, registered
FROM discovery_results WHERE id = $1"#,
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
let result = result.ok_or_else(|| (
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Discovery result not found" } }))
))?;
let fqdn = result.fqdn.as_deref().unwrap_or(&result.ip_address);
let display_name = req.display_name.as_deref().unwrap_or(fqdn);
let host_id: Uuid = sqlx::query_scalar(
r#"INSERT INTO hosts (fqdn, ip_address, display_name, agent_port)
VALUES ($1, $2::inet, $3, $4)
ON CONFLICT DO NOTHING
RETURNING id"#,
)
.bind(fqdn)
.bind(&result.ip_address)
.bind(display_name)
.bind(result.agent_port)
.fetch_one(&state.db)
.await
.map_err(|e| {
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": e.to_string() } })),
)
})?;
// Assign to groups
if let Some(group_ids) = &req.group_ids {
for gid in group_ids {
let _ = sqlx::query("INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING")
.bind(host_id).bind(gid).execute(&state.db).await;
}
}
// Mark as registered
let _ = sqlx::query("UPDATE discovery_results SET registered = TRUE WHERE id = $1")
.bind(id)
.execute(&state.db)
.await;
log_event(
&state.db,
AuditAction::HostRegistered,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&host_id.to_string()),
json!({ "from_discovery": true, "ip": result.ip_address }),
None,
None,
)
.await;
Ok(Json(
json!({ "host_id": host_id, "message": "Host registered from discovery" }),
))
}

View File

@ -0,0 +1,319 @@
use crate::AppState;
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
Json, Router,
};
use chrono::Utc;
use pm_auth::AuthUser;
use pm_core::{
db,
models::{
CreateEnrollmentRequest, EnrollmentRequest, EnrollmentStatusResponse, Host, PkiBundle,
},
};
use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize;
#[derive(Debug, Clone, Serialize)]
pub struct HostConflict {
pub existing_host: Host,
pub message: String,
}
/// Define public enrollment routes.
pub fn router() -> Router<AppState> {
Router::new()
.route("/enroll", post(enroll_host))
.route("/enroll/status/{token}", get(enroll_status))
}
/// POST /api/v1/enroll
/// Initiates host self-enrollment.
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
async fn enroll_host(
State(state): State<AppState>,
Json(payload): Json<CreateEnrollmentRequest>,
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
// Generate secure random polling token
let polling_token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
// For database storage, we'll hash the token (spec says hashed)
// Using a simple SHA256 or similar for the hash storage
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(polling_token.as_bytes());
let token_hash = hex::encode(hasher.finalize());
// 3. Store in DB
db::create_enrollment_request(&state.db, payload, token_hash)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to create enrollment request");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// 4. Return the raw token to the client
Ok((
StatusCode::ACCEPTED,
Json(serde_json::json!({ "polling_token": polling_token })),
)
.into_response())
}
/// GET /api/v1/enroll/status/{token}
/// Returns status of enrollment (pending/approved/denied/not_found).
async fn enroll_status(
State(state): State<AppState>,
Path(token): Path<String>,
) -> Result<Json<EnrollmentStatusResponse>, (StatusCode, Json<serde_json::Value>)> {
// Hash the provided token to match DB
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(token.as_bytes());
let token_hash = hex::encode(hasher.finalize());
// 1. Check enrollment_requests table
let requests = db::list_enrollment_requests(&state.db).await.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
if let Some(req) = requests.into_iter().find(|r| r.polling_token == token_hash) {
if req.expires_at < Utc::now() {
return Ok(Json(EnrollmentStatusResponse::NotFound));
}
return Ok(Json(EnrollmentStatusResponse::Pending));
}
// 2. If not in pending, check if it was recently approved.
if let Some(pki) = state.approved_enrollments.get(&token_hash) {
return Ok(Json(EnrollmentStatusResponse::Approved {
ca_crt: pki.ca_crt.clone(),
server_crt: pki.server_crt.clone(),
server_key: pki.server_key.clone(),
}));
}
Ok(Json(EnrollmentStatusResponse::NotFound))
}
/// Define admin enrollment routes.
pub fn admin_router() -> Router<AppState> {
Router::new()
.route("/enrollments", get(list_admin_enrollments))
.route("/enrollments/{id}/approve", post(approve_enrollment))
.route("/enrollments/{id}/deny", delete(deny_enrollment))
}
/// GET /api/v1/admin/enrollments
/// Lists all pending enrollment requests.
async fn list_admin_enrollments(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Vec<EnrollmentRequest>>, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
db::list_enrollment_requests(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e, "Failed to list enrollment requests");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})
}
/// POST /api/v1/admin/enrollments/{id}/approve
/// Approves a pending enrollment request, generates PKI, and moves to hosts table.
async fn approve_enrollment(
State(state): State<AppState>,
Path(id): Path<uuid::Uuid>,
auth: AuthUser,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
// Fetch the enrollment request
let mut requests = db::list_enrollment_requests(&state.db).await.map_err(|e| {
tracing::error!(error = %e, "Failed to list enrollment requests for approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
let enrollment_request = match requests.iter().position(|r| r.id == id) {
Some(idx) => requests.remove(idx),
None => return Ok(StatusCode::NOT_FOUND),
};
// Check for FQDN/IP collision in hosts table
if let Some(existing_host) = sqlx::query_as::<_, Host>(
"SELECT id, fqdn, ip_address::text, display_name, os_family, os_name, arch, agent_version, health_status, last_health_at, last_patch_at, agent_port, notes, registered_at, updated_at FROM hosts WHERE fqdn = $1 OR ip_address = $2::inet"
)
.bind(&enrollment_request.fqdn)
.bind(enrollment_request.ip_address.to_string())
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to check for host collision");
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": "Database error" })))
})? {
return Err((
StatusCode::CONFLICT,
Json(serde_json::json!({ "error": "Host collision detected", "conflict": HostConflict { existing_host, message: "FQDN or IP already exists".to_string() } }))
));
}
// Move to hosts table FIRST (certificates table has FK reference to hosts)
let os_family = enrollment_request
.os_details
.get("os")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let os_name = enrollment_request
.os_details
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| {
// Build os_name from os + os_version if "name" is absent
let os = enrollment_request
.os_details
.get("os")
.and_then(|v| v.as_str())?;
let ver = enrollment_request
.os_details
.get("os_version")
.and_then(|v| v.as_str())
.unwrap_or("");
Some(format!("{} {}", os, ver).trim().to_string())
});
let arch = enrollment_request
.os_details
.get("architecture")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let display_name = enrollment_request
.hostname
.clone()
.unwrap_or_else(|| enrollment_request.fqdn.clone());
sqlx::query(
r#"
INSERT INTO hosts (id, fqdn, ip_address, os_family, os_name, arch, display_name, registered_at, updated_at)
VALUES ($1, $2, $3::inet, $4, $5, $6, $7, NOW(), NOW())
"#,
)
.bind(enrollment_request.id)
.bind(&enrollment_request.fqdn)
.bind(enrollment_request.ip_address.to_string())
.bind(&os_family)
.bind(&os_name)
.bind(&arch)
.bind(&display_name)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to insert host after approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// Generate PKI bundle using CA (after host row exists)
let issued = state
.ca
.issue_client_cert(
enrollment_request.id,
&enrollment_request.fqdn,
&enrollment_request.ip_address.to_string(),
&state.db,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to issue client certificate");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Certificate generation failed" })),
)
})?;
// Delete from enrollment_requests table
db::delete_enrollment_request(&state.db, id)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to delete enrollment request after approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// Store PKI bundle in cache for client retrieval
let pki = PkiBundle {
ca_crt: issued.ca_root_pem,
server_crt: issued.server_cert_pem,
server_key: issued.server_key_pem,
};
state
.approved_enrollments
.insert(enrollment_request.polling_token.clone(), pki);
Ok(StatusCode::OK)
}
/// DELETE /api/v1/admin/enrollments/{id}/deny
/// Denies and purges a pending enrollment request.
async fn deny_enrollment(
State(state): State<AppState>,
Path(id): Path<uuid::Uuid>,
auth: AuthUser,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
db::delete_enrollment_request(&state.db, id)
.await
.map(|_| StatusCode::NO_CONTENT)
.map_err(|e| {
tracing::error!(error = %e, "Failed to deny enrollment request");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})
}

View File

@ -0,0 +1,312 @@
//! Group management routes.
//!
//! GET /api/v1/groups — list all groups
//! POST /api/v1/groups — create group (admin)
//! GET /api/v1/groups/:id — get group detail + members
//! PUT /api/v1/groups/:id — update group (admin)
//! DELETE /api/v1/groups/:id — delete group (admin)
//! POST /api/v1/groups/:id/users/:user_id — add user to group (admin)
//! DELETE /api/v1/groups/:id/users/:user_id — remove user from group (admin)
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateGroupRequest, Group, UpdateGroupRequest},
};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_groups).post(create_group))
.route(
"/{id}",
get(get_group).put(update_group).delete(delete_group),
)
.route(
"/{id}/users/{user_id}",
post(add_user_to_group).delete(remove_user_from_group),
)
}
async fn list_groups(
State(state): State<AppState>,
_auth: AuthUser,
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
sqlx::query_as::<_, Group>(
"SELECT id, name, description, created_at, updated_at FROM groups ORDER BY name",
)
.fetch_all(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e, "Failed to list groups");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})
}
async fn create_group(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateGroupRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let id: Uuid =
sqlx::query_scalar("INSERT INTO groups (name, description) VALUES ($1, $2) RETURNING id")
.bind(&req.name)
.bind(req.description.as_deref().unwrap_or(""))
.fetch_one(&state.db)
.await
.map_err(|e| {
let msg = if e.to_string().contains("unique") {
"Group name already exists".to_string()
} else {
"Database error".to_string()
};
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupCreated,
Some(auth.user_id),
Some(&auth.username),
Some("group"),
Some(&id.to_string()),
json!({ "name": req.name }),
None,
None,
)
.await;
Ok(Json(json!({ "id": id, "message": "Group created" })))
}
async fn get_group(
State(state): State<AppState>,
_auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let group: Option<Group> = sqlx::query_as(
"SELECT id, name, description, created_at, updated_at FROM groups WHERE id = $1",
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let group = group.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
)
})?;
// Fetch member counts
let host_count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM host_groups WHERE group_id = $1")
.bind(id)
.fetch_one(&state.db)
.await
.unwrap_or(0);
let user_count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM user_groups WHERE group_id = $1")
.bind(id)
.fetch_one(&state.db)
.await
.unwrap_or(0);
Ok(Json(
json!({ "group": group, "host_count": host_count, "user_count": user_count }),
))
}
async fn update_group(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<UpdateGroupRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let rows = sqlx::query(
"UPDATE groups SET name = COALESCE($1, name), description = COALESCE($2, description), updated_at = NOW() WHERE id = $3"
)
.bind(req.name.as_deref())
.bind(req.description.as_deref())
.bind(id)
.execute(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
));
}
Ok(Json(json!({ "message": "Group updated" })))
}
async fn delete_group(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let rows = sqlx::query("DELETE FROM groups WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
));
}
log_event(
&state.db,
AuditAction::GroupDeleted,
Some(auth.user_id),
Some(&auth.username),
Some("group"),
Some(&id.to_string()),
json!({}),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Group deleted" })))
}
async fn add_user_to_group(
State(state): State<AppState>,
auth: AuthUser,
Path((id, user_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query(
"INSERT INTO user_groups (user_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
)
.bind(user_id)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("user_group"),
Some(&id.to_string()),
json!({ "user_id": user_id, "action": "added" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User added to group" })))
}
async fn remove_user_from_group(
State(state): State<AppState>,
auth: AuthUser,
Path((id, user_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query("DELETE FROM user_groups WHERE user_id = $1 AND group_id = $2")
.bind(user_id)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("user_group"),
Some(&id.to_string()),
json!({ "user_id": user_id, "action": "removed" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User removed from group" })))
}

File diff suppressed because it is too large Load Diff

678
crates/pm-web/src/routes/hosts.rs Executable file
View File

@ -0,0 +1,678 @@
//! Host management routes.
//!
//! GET /api/v1/hosts — list hosts (RBAC scoped)
//! POST /api/v1/hosts — register new host (admin only)
//! GET /api/v1/hosts/{id} — get host detail
//! DELETE /api/v1/hosts/{id} — remove host (admin only)
//! PUT /api/v1/hosts/{id} — update host (write access)
//! GET /api/v1/hosts/{id}/groups — list groups for host
//! POST /api/v1/hosts/{id}/groups — assign host to group
//! DELETE /api/v1/hosts/{id}/groups/{group_id} — remove host from group
//! POST /api/v1/hosts/{id}/refresh — queue on-demand refresh (write access)
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
routing::{delete, get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateHostRequest, Group, HostSummary, UpdateHostRequest},
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_hosts).post(register_host))
.route("/{id}", get(get_host).put(update_host).delete(remove_host))
.route(
"/{id}/groups",
get(list_host_groups).post(add_host_to_group),
)
.route("/{id}/groups/{group_id}", delete(remove_host_from_group))
.route("/{id}/refresh", post(refresh_host))
}
// ── Query params ─────────────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct HostListQuery {
pub group_id: Option<Uuid>,
pub health_status: Option<String>,
pub os_family: Option<String>,
pub search: Option<String>,
pub limit: Option<i64>,
pub offset: Option<i64>,
}
// ── Response types ────────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct HostListResponse {
hosts: Vec<HostSummary>,
total: i64,
limit: i64,
offset: i64,
}
// ── Helper: check if operator can access a host ───────────────────────────────
async fn operator_can_access_host(
pool: &sqlx::PgPool,
user_id: Uuid,
host_id: Uuid,
) -> Result<bool, sqlx::Error> {
// Admins can access all; operators can access hosts in their groups
// OR ungrouped hosts (no group memberships)
let in_group: bool = sqlx::query_scalar(
r#"
SELECT EXISTS (
SELECT 1 FROM host_groups hg
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE hg.host_id = $1 AND ug.user_id = $2
)
"#,
)
.bind(host_id)
.bind(user_id)
.fetch_one(pool)
.await?;
if in_group {
return Ok(true);
}
// Ungrouped hosts are accessible to any operator
let ungrouped: bool =
sqlx::query_scalar("SELECT NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = $1)")
.bind(host_id)
.fetch_one(pool)
.await?;
Ok(ungrouped)
}
// ── GET /api/v1/hosts ─────────────────────────────────────────────────────────
async fn list_hosts(
State(state): State<AppState>,
auth: AuthUser,
Query(q): Query<HostListQuery>,
) -> Result<Json<HostListResponse>, (StatusCode, Json<Value>)> {
let limit = q.limit.unwrap_or(50).min(200);
let offset = q.offset.unwrap_or(0);
// For operators: only show hosts in their groups (or ungrouped)
let hosts: Vec<HostSummary> = if auth.role.is_admin() {
sqlx::query_as(
r#"
SELECT h.id, h.fqdn, host(h.ip_address)::text AS ip_address, h.display_name,
h.os_family, h.os_name, h.health_status, h.agent_version,
COALESCE(hpd.patch_count, 0) AS patches_missing,
CASE
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
THEN NULL
WHEN EXISTS (
SELECT 1 FROM host_health_checks hc
LEFT JOIN LATERAL (
SELECT healthy FROM host_health_check_results r
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
) lr ON TRUE
WHERE hc.host_id = h.id AND hc.enabled = TRUE
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
)
THEN 'some_unhealthy'
ELSE 'all_healthy'
END AS health_check_status,
h.registered_at
FROM hosts h
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
ORDER BY h.fqdn
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await
} else {
sqlx::query_as(
r#"
SELECT DISTINCT h.id, h.fqdn, host(h.ip_address)::text AS ip_address,
h.display_name, h.os_family, h.os_name,
h.health_status, h.agent_version,
COALESCE(hpd.patch_count, 0) AS patches_missing,
CASE
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
THEN NULL
WHEN EXISTS (
SELECT 1 FROM host_health_checks hc
LEFT JOIN LATERAL (
SELECT healthy FROM host_health_check_results r
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
) lr ON TRUE
WHERE hc.host_id = h.id AND hc.enabled = TRUE
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
)
THEN 'some_unhealthy'
ELSE 'all_healthy'
END AS health_check_status,
h.registered_at
FROM hosts h
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
WHERE
-- Hosts in operator's groups
EXISTS (
SELECT 1 FROM host_groups hg
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE hg.host_id = h.id AND ug.user_id = $3
)
-- OR ungrouped hosts
OR NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = h.id)
ORDER BY h.fqdn
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.bind(auth.user_id)
.fetch_all(&state.db)
.await
}
.map_err(|e| {
tracing::error!(error = %e, "Failed to list hosts");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM hosts")
.fetch_one(&state.db)
.await
.unwrap_or(0);
Ok(Json(HostListResponse {
hosts,
total,
limit,
offset,
}))
}
// ── POST /api/v1/hosts ────────────────────────────────────────────────────────
async fn register_host(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateHostRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Admin only
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Resolve FQDN to IP address
let ip_address = resolve_fqdn(&req.fqdn).await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "fqdn_resolution_failed", "message": e } })),
)
})?;
let display_name = req.display_name.clone().unwrap_or_else(|| req.fqdn.clone());
let agent_port = req.agent_port.unwrap_or(12443);
let notes = req.notes.clone().unwrap_or_default();
// Insert host
let host_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO hosts (fqdn, ip_address, display_name, agent_port, notes)
VALUES ($1, $2::inet, $3, $4, $5)
RETURNING id
"#,
)
.bind(&req.fqdn)
.bind(&ip_address)
.bind(&display_name)
.bind(agent_port)
.bind(&notes)
.fetch_one(&state.db)
.await
.map_err(|e| {
let msg = if e.to_string().contains("unique") {
"Host with this FQDN and IP already exists".to_string()
} else {
"Database error".to_string()
};
tracing::error!(error = %e, "Failed to register host");
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
// Assign to groups if specified
if let Some(group_ids) = &req.group_ids {
for gid in group_ids {
let _ = sqlx::query(
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
)
.bind(host_id)
.bind(gid)
.execute(&state.db)
.await;
}
}
// Audit log
log_event(
&state.db,
AuditAction::HostRegistered,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&host_id.to_string()),
json!({ "fqdn": req.fqdn, "ip": ip_address }),
None,
None,
)
.await;
tracing::info!(host_id = %host_id, fqdn = %req.fqdn, "Host registered");
Ok(Json(json!({ "id": host_id, "message": "Host registered" })))
}
// ── GET /api/v1/hosts/:id ─────────────────────────────────────────────────────
async fn get_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
.await
.unwrap_or(false);
if !can_access {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
}
let host: Option<Value> = sqlx::query_scalar(
r#"
SELECT row_to_json(h) FROM (
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
os_family, os_name, arch, agent_version, health_status,
last_health_at, last_patch_at, agent_port, notes,
registered_at, updated_at
FROM hosts WHERE id = $1
) h
"#,
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to get host");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
host.map(Json).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
})
}
// ── DELETE /api/v1/hosts/:id ──────────────────────────────────────────────────
async fn remove_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Fetch FQDN for audit before deletion
let fqdn: Option<String> = sqlx::query_scalar("SELECT fqdn FROM hosts WHERE id = $1")
.bind(id)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let result = sqlx::query("DELETE FROM hosts WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to remove host");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
if result.rows_affected() == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
));
}
log_event(
&state.db,
AuditAction::HostRemoved,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&id.to_string()),
json!({ "fqdn": fqdn }),
None,
None,
)
.await;
tracing::info!(host_id = %id, "Host removed");
Ok(Json(json!({ "message": "Host removed" })))
}
// ── PUT /api/v1/hosts/:id ─────────────────────────────────────────────────────
async fn update_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<UpdateHostRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Update only fields that were provided; COALESCE preserves existing values.
let host = sqlx::query_scalar(
r#"
WITH updated AS (
UPDATE hosts SET
fqdn = COALESCE($1, fqdn),
ip_address = COALESCE($2::inet, ip_address),
display_name = COALESCE($3, display_name),
updated_at = NOW()
WHERE id = $4
RETURNING id
)
SELECT row_to_json(h) FROM (
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
os_family, os_name, arch, agent_version, health_status,
last_health_at, last_patch_at, agent_port, notes,
registered_at, updated_at
FROM hosts WHERE id = (SELECT id FROM updated)
) h
"#,
)
.bind(&req.fqdn)
.bind(&req.ip_address)
.bind(&req.display_name)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, host_id = %id, "Failed to update host");
let msg = if e.to_string().contains("unique") {
"A host with this FQDN and IP already exists".to_string()
} else {
"Database error".to_string()
};
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
host.map(Json).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
})
}
// ── GET /api/v1/hosts/:id/groups ──────────────────────────────────────────────
async fn list_host_groups(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
.await
.unwrap_or(false);
if !can_access {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
}
let groups: Vec<Group> = sqlx::query_as(
r#"SELECT g.id, g.name, g.description, g.created_at, g.updated_at
FROM groups g
JOIN host_groups hg ON hg.group_id = g.id
WHERE hg.host_id = $1
ORDER BY g.name"#,
)
.bind(id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to list host groups");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(Json(groups))
}
// ── POST /api/v1/hosts/:id/groups ─────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct AddToGroupRequest {
group_id: Uuid,
}
async fn add_host_to_group(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<AddToGroupRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query(
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
)
.bind(id)
.bind(req.group_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to add host to group");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&id.to_string()),
json!({ "group_id": req.group_id, "action": "added" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Host added to group" })))
}
// ── DELETE /api/v1/hosts/:id/groups/:group_id ─────────────────────────────────
async fn remove_host_from_group(
State(state): State<AppState>,
auth: AuthUser,
Path((id, group_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query("DELETE FROM host_groups WHERE host_id = $1 AND group_id = $2")
.bind(id)
.bind(group_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to remove host from group");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&id.to_string()),
json!({ "group_id": group_id, "action": "removed" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Host removed from group" })))
}
// ── FQDN resolution ───────────────────────────────────────────────────────────
/// Resolve an FQDN (or IP) to its primary IP address.
/// If the input is already a valid IP, returns it as-is.
async fn resolve_fqdn(fqdn: &str) -> Result<String, String> {
use std::net::ToSocketAddrs;
// Try direct IP parse first
if fqdn.parse::<std::net::IpAddr>().is_ok() {
return Ok(fqdn.to_string());
}
// DNS resolution
let addr = format!("{fqdn}:0");
match tokio::task::spawn_blocking(move || addr.to_socket_addrs()).await {
Ok(Ok(mut addrs)) => addrs
.next()
.map(|a| a.ip().to_string())
.ok_or_else(|| format!("No addresses found for {fqdn}")),
_ => Err(format!("Failed to resolve FQDN: {fqdn}")),
}
}
// ── POST /api/v1/hosts/:id/refresh ───────────────────────────────────────────
/// Queue an on-demand health + patch refresh for a single host.
///
/// Sends a PostgreSQL NOTIFY on the `refresh_requested` channel; the
/// pm-worker refresh listener picks this up and polls the host immediately.
/// Requires Operator or Admin role (any authenticated user).
async fn refresh_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<(StatusCode, Json<Value>), (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Verify the host exists.
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
.bind(id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "refresh_host: db error checking host existence");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
if !exists {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
));
}
// NOTIFY the worker's refresh listener.
sqlx::query("SELECT pg_notify('refresh_requested', $1)")
.bind(id.to_string())
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "refresh_host: pg_notify failed");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to queue refresh" } })),
)
})?;
tracing::info!(%id, "On-demand refresh queued");
Ok((
StatusCode::ACCEPTED,
Json(json!({ "message": "Refresh queued" })),
))
}

677
crates/pm-web/src/routes/jobs.rs Executable file
View File

@ -0,0 +1,677 @@
//! Patch job management routes.
//!
//! POST /api/v1/jobs — create a new patch job (operator+)
//! GET /api/v1/jobs — list jobs with pagination (RBAC scoped)
//! GET /api/v1/jobs/{id} — get job detail + per-host status
//! POST /api/v1/jobs/{id}/cancel — cancel a queued/pending job (admin or creator)
//! POST /api/v1/jobs/{id}/rollback — create a rollback job (admin only)
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateJobRequest, PatchJobSummary},
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
// ── Router ────────────────────────────────────────────────────────────────────
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_jobs).post(create_job))
.route("/{id}", get(get_job))
.route("/{id}/cancel", post(cancel_job))
.route("/{id}/rollback", post(rollback_job))
}
// ── Query params ──────────────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
pub struct JobListQuery {
pub limit: Option<i64>,
pub offset: Option<i64>,
}
// ── Response types ────────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct JobListResponse {
jobs: Vec<PatchJobSummary>,
total: i64,
limit: i64,
offset: i64,
}
/// Per-host row included in `GET /api/v1/jobs/{id}` response.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
struct JobHostRow {
pub id: Uuid,
pub host_id: Uuid,
pub display_name: String,
pub status: String,
pub agent_job_id: Option<String>,
pub retry_count: i32,
pub output: String,
pub error_message: Option<String>,
pub last_error: Option<String>,
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
}
// ── Error helper ──────────────────────────────────────────────────────────────
#[inline]
fn err(
status: StatusCode,
code: &'static str,
message: impl Into<String>,
) -> (StatusCode, Json<Value>) {
(
status,
Json(json!({ "error": { "code": code, "message": message.into() } })),
)
}
// ── RBAC helper ───────────────────────────────────────────────────────────────
/// Returns `true` when the operator's groups contain at least one host that
/// belongs to the given job. Admins always pass this check at the call site.
async fn operator_can_access_job(
pool: &sqlx::PgPool,
user_id: Uuid,
job_id: Uuid,
) -> Result<bool, sqlx::Error> {
sqlx::query_scalar(
r#"
SELECT EXISTS (
SELECT 1
FROM patch_job_hosts pjh
JOIN host_groups hg ON hg.host_id = pjh.host_id
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE pjh.job_id = $1
AND ug.user_id = $2
)
"#,
)
.bind(job_id)
.bind(user_id)
.fetch_one(pool)
.await
}
// ── POST /api/v1/jobs ─────────────────────────────────────────────────────────
async fn create_job(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateJobRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
if req.host_ids.is_empty() {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"host_ids must not be empty",
));
}
// Encode package list as JSONB.
let patch_selection = serde_json::to_value(&req.packages).unwrap_or(json!([]));
let notes = req.notes.clone().unwrap_or_default();
// Insert the parent job row; the DB NOTIFY trigger fires automatically
// when immediate = TRUE (see migration 003_jobs_scheduling.sql).
let job_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO patch_jobs
(kind, status, created_by_user_id, maintenance_window_id,
immediate, patch_selection, notes)
VALUES
('patch_apply'::job_kind, 'queued'::job_status, $1, $2, $3, $4, $5)
RETURNING id
"#,
)
.bind(auth.user_id)
.bind(req.maintenance_window_id)
.bind(req.immediate)
.bind(&patch_selection)
.bind(&notes)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "create_job: insert patch_jobs failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Insert one patch_job_hosts row per requested host.
for host_id in &req.host_ids {
sqlx::query(
r#"
INSERT INTO patch_job_hosts (job_id, host_id, status)
VALUES ($1, $2, 'queued'::job_status)
ON CONFLICT (job_id, host_id) DO NOTHING
"#,
)
.bind(job_id)
.bind(host_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(
error = %e, %job_id, %host_id,
"create_job: insert patch_job_hosts failed"
);
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
}
log_event(
&state.db,
AuditAction::PatchJobCreated,
Some(auth.user_id),
Some(&auth.username),
Some("job"),
Some(&job_id.to_string()),
json!({
"kind": "patch_apply",
"immediate": req.immediate,
"host_count": req.host_ids.len(),
"packages": req.packages,
"notes": notes,
}),
None,
None,
)
.await;
tracing::info!(
job_id = %job_id,
host_count = req.host_ids.len(),
immediate = req.immediate,
user = %auth.username,
"Patch job created"
);
Ok(Json(json!({ "id": job_id, "message": "Job created" })))
}
// ── GET /api/v1/jobs ──────────────────────────────────────────────────────────
async fn list_jobs(
State(state): State<AppState>,
auth: AuthUser,
Query(q): Query<JobListQuery>,
) -> Result<Json<JobListResponse>, (StatusCode, Json<Value>)> {
let limit = q.limit.unwrap_or(50).min(200);
let offset = q.offset.unwrap_or(0);
let jobs: Vec<PatchJobSummary> = if auth.role.is_admin() {
// Admins see every job.
sqlx::query_as(
r#"
SELECT
pj.id,
pj.kind,
pj.status,
pj.immediate,
pj.notes,
pj.created_at,
pj.started_at,
pj.completed_at,
COUNT(pjh.id) AS host_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
FROM patch_jobs pj
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
GROUP BY pj.id
ORDER BY pj.created_at DESC
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await
} else {
// Operators: only jobs where at least one host is in their groups.
sqlx::query_as(
r#"
SELECT
pj.id,
pj.kind,
pj.status,
pj.immediate,
pj.notes,
pj.created_at,
pj.started_at,
pj.completed_at,
COUNT(pjh.id) AS host_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
FROM patch_jobs pj
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
WHERE EXISTS (
SELECT 1
FROM patch_job_hosts pjh2
JOIN host_groups hg ON hg.host_id = pjh2.host_id
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE pjh2.job_id = pj.id
AND ug.user_id = $3
)
GROUP BY pj.id
ORDER BY pj.created_at DESC
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.bind(auth.user_id)
.fetch_all(&state.db)
.await
}
.map_err(|e| {
tracing::error!(error = %e, "list_jobs: query failed");
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
})?;
// Total count for pagination metadata.
let total: i64 = if auth.role.is_admin() {
sqlx::query_scalar("SELECT COUNT(*) FROM patch_jobs")
.fetch_one(&state.db)
.await
.unwrap_or(0)
} else {
sqlx::query_scalar(
r#"
SELECT COUNT(DISTINCT pj.id)
FROM patch_jobs pj
WHERE EXISTS (
SELECT 1
FROM patch_job_hosts pjh
JOIN host_groups hg ON hg.host_id = pjh.host_id
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE pjh.job_id = pj.id
AND ug.user_id = $1
)
"#,
)
.bind(auth.user_id)
.fetch_one(&state.db)
.await
.unwrap_or(0)
};
Ok(Json(JobListResponse {
jobs,
total,
limit,
offset,
}))
}
// ── GET /api/v1/jobs/:id ─────────────────────────────────────────────────────
async fn get_job(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// RBAC: operators may only view jobs touching their group's hosts.
if !auth.role.is_admin() {
let allowed = operator_can_access_job(&state.db, auth.user_id, id)
.await
.unwrap_or(false);
if !allowed {
return Err(err(StatusCode::FORBIDDEN, "forbidden", "Access denied"));
}
}
// Fetch the job header row as JSON.
let job: Option<Value> = sqlx::query_scalar(
r#"
SELECT row_to_json(j) FROM (
SELECT id, kind, status, created_by_user_id, parent_job_id,
maintenance_window_id, immediate, patch_selection, notes,
created_at, started_at, completed_at
FROM patch_jobs
WHERE id = $1
) j
"#,
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "get_job: failed to fetch job");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
let job = job.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
// Fetch per-host status rows joined to the host display name.
let hosts: Vec<JobHostRow> = sqlx::query_as(
r#"
SELECT
pjh.id,
pjh.host_id,
COALESCE(h.display_name, h.fqdn) AS display_name,
pjh.status::text AS status,
pjh.agent_job_id,
pjh.retry_count,
pjh.output,
pjh.error_message,
pjh.last_error,
pjh.started_at,
pjh.completed_at
FROM patch_job_hosts pjh
JOIN hosts h ON h.id = pjh.host_id
WHERE pjh.job_id = $1
ORDER BY h.display_name
"#,
)
.bind(id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "get_job: failed to fetch host rows");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
Ok(Json(json!({ "job": job, "hosts": hosts })))
}
// ── POST /api/v1/jobs/:id/cancel ─────────────────────────────────────────────
async fn cancel_job(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Fetch the job to verify it exists and check ownership.
let row: Option<(String, Option<Uuid>)> =
sqlx::query_as("SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1")
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "cancel_job: db fetch failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
let (status_str, creator_id) =
row.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
// Only admin or the job creator may cancel.
if !auth.role.can_write() {
let is_creator = creator_id == Some(auth.user_id);
if !is_creator {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
}
// Only queued or pending jobs can be cancelled.
if status_str != "queued" && status_str != "pending" {
return Err(err(
StatusCode::CONFLICT,
"invalid_state",
format!(
"Cannot cancel a job in '{}' state; only queued or pending jobs may be cancelled",
status_str
),
));
}
// Cancel the parent job.
sqlx::query("UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Cancel all queued/pending host rows for this job.
sqlx::query(
r#"
UPDATE patch_job_hosts
SET status = 'cancelled'::job_status
WHERE job_id = $1
AND status IN ('queued'::job_status, 'pending'::job_status)
"#,
)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "cancel_job: update patch_job_hosts failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Fire job-level pg_notify so the frontend can update the job row.
let notify_payload = json!({
"event_type": "job",
"job_id": id.to_string(),
"host_id": "",
"status": "cancelled",
"succeeded_count": 0,
"failed_count": 0,
"host_count": 0,
});
if let Ok(payload_str) = serde_json::to_string(&notify_payload) {
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
.bind(&payload_str)
.execute(&state.db)
.await
{
tracing::error!(error = %e, %id, "cancel_job: job-level pg_notify failed");
} else {
tracing::info!(%id, "cancel_job: job-level pg_notify sent");
}
}
log_event(
&state.db,
AuditAction::PatchJobCancelled,
Some(auth.user_id),
Some(&auth.username),
Some("job"),
Some(&id.to_string()),
json!({ "previous_status": status_str }),
None,
None,
)
.await;
tracing::info!(job_id = %id, user = %auth.username, "Patch job cancelled");
Ok(Json(json!({ "message": "Job cancelled" })))
}
// ── POST /api/v1/jobs/:id/rollback ────────────────────────────────────────────
async fn rollback_job(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Admin-only operation.
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
// Verify the original job exists.
let original_exists: bool =
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM patch_jobs WHERE id = $1)")
.bind(id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "rollback_job: existence check failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if !original_exists {
return Err(err(StatusCode::NOT_FOUND, "not_found", "Job not found"));
}
// Gather the host IDs from the original job.
let host_ids: Vec<Uuid> =
sqlx::query_scalar("SELECT host_id FROM patch_job_hosts WHERE job_id = $1")
.bind(id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "rollback_job: host fetch failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if host_ids.is_empty() {
return Err(err(
StatusCode::UNPROCESSABLE_ENTITY,
"no_hosts",
"Original job has no host entries to roll back",
));
}
// Create the rollback job row (immediate = true so the worker picks it up
// right away and the NOTIFY trigger fires).
let rollback_job_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO patch_jobs
(kind, status, created_by_user_id, parent_job_id, immediate,
patch_selection, notes)
VALUES
('rollback'::job_kind, 'queued'::job_status, $1, $2, TRUE,
'[]'::jsonb, $3)
RETURNING id
"#,
)
.bind(auth.user_id)
.bind(id)
.bind(format!("Rollback of job {}", id))
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, parent_job_id = %id, "rollback_job: insert failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Replicate host list into the rollback job.
for host_id in &host_ids {
sqlx::query(
r#"
INSERT INTO patch_job_hosts (job_id, host_id, status)
VALUES ($1, $2, 'queued'::job_status)
ON CONFLICT (job_id, host_id) DO NOTHING
"#,
)
.bind(rollback_job_id)
.bind(host_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(
error = %e, %rollback_job_id, %host_id,
"rollback_job: insert patch_job_hosts failed"
);
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
}
log_event(
&state.db,
AuditAction::PatchJobRollback,
Some(auth.user_id),
Some(&auth.username),
Some("job"),
Some(&rollback_job_id.to_string()),
json!({
"original_job_id": id,
"rollback_job_id": rollback_job_id,
"host_count": host_ids.len(),
}),
None,
None,
)
.await;
tracing::info!(
rollback_job_id = %rollback_job_id,
original_job_id = %id,
user = %auth.username,
"Rollback job created"
);
Ok(Json(json!({
"id": rollback_job_id,
"parent_job_id": id,
"message": "Rollback job created"
})))
}

View File

@ -0,0 +1,452 @@
//! Maintenance window management routes.
//!
//! GET /api/v1/hosts/{id}/maintenance-windows — list windows for host
//! GET /api/v1/maintenance-windows — list ALL windows (bulk)
//! POST /api/v1/hosts/{id}/maintenance-windows — create window for host
//! PUT /api/v1/hosts/{id}/maintenance-windows/{win_id} — update window
//! DELETE /api/v1/hosts/{id}/maintenance-windows/{win_id} — delete window
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, put},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest},
};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
// ── Router ────────────────────────────────────────────────────────────────────
/// Mount as a nested router under `/hosts/{host_id}/maintenance-windows`.
/// Axum will merge the `{host_id}` path segment from the parent nest.
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_windows).post(create_window))
.route("/{win_id}", put(update_window).delete(delete_window))
}
/// Top-level router for `/api/v1/maintenance-windows` — bulk list-all endpoint.
pub fn all_windows_router() -> Router<AppState> {
Router::new().route("/", get(list_all_windows))
}
// ── GET /api/v1/maintenance-windows ──────────────────────────────────────────
/// Bulk endpoint: return every maintenance window across all hosts.
/// Eliminates N+1 queries from the frontend (one request instead of one per host).
async fn list_all_windows(
State(state): State<AppState>,
_auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
r#"
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
FROM maintenance_windows
ORDER BY host_id, created_at ASC
"#,
)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "list_all_windows: query failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
Ok(Json(json!({ "windows": windows })))
}
// ── Error helper ──────────────────────────────────────────────────────────────
#[inline]
fn err(
status: StatusCode,
code: &'static str,
message: impl Into<String>,
) -> (StatusCode, Json<Value>) {
(
status,
Json(json!({ "error": { "code": code, "message": message.into() } })),
)
}
// ── GET /api/v1/hosts/:host_id/maintenance-windows ────────────────────────────
async fn list_windows(
State(state): State<AppState>,
_auth: AuthUser,
Path(host_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Verify host exists.
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "list_windows: host existence check failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if !host_exists {
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
}
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
r#"
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
FROM maintenance_windows
WHERE host_id = $1
ORDER BY created_at ASC
"#,
)
.bind(host_id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "list_windows: query failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
Ok(Json(json!({ "windows": windows })))
}
// ── POST /api/v1/hosts/:host_id/maintenance-windows ───────────────────────────
async fn create_window(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
Json(req): Json<CreateMaintenanceWindowRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
// Validate: weekly requires recurrence_day 0-6
if req.recurrence == pm_core::models::WindowRecurrence::Weekly {
match req.recurrence_day {
Some(d) if (0..=6).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Weekly recurrence requires recurrence_day 0-6 (0=Sunday)",
));
},
}
}
// Validate: monthly requires recurrence_day 1-31
if req.recurrence == pm_core::models::WindowRecurrence::Monthly {
match req.recurrence_day {
Some(d) if (1..=31).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Monthly recurrence requires recurrence_day 1-31",
));
},
}
}
// Verify host exists.
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "create_window: host existence check failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if !host_exists {
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
}
let duration = req.duration_minutes.unwrap_or(60);
let enabled = req.enabled.unwrap_or(true);
let auto_apply = req.auto_apply.unwrap_or(true);
let window: MaintenanceWindow = sqlx::query_as(
r#"
INSERT INTO maintenance_windows
(host_id, label, recurrence, start_at, duration_minutes, recurrence_day, enabled, auto_apply)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
"#,
)
.bind(host_id)
.bind(&req.label)
.bind(&req.recurrence)
.bind(req.start_at)
.bind(duration)
.bind(req.recurrence_day)
.bind(enabled)
.bind(auto_apply)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "create_window: insert failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
log_event(
&state.db,
AuditAction::MaintenanceWindowCreated,
Some(auth.user_id),
Some(&auth.username),
Some("maintenance_window"),
Some(&window.id.to_string()),
json!({
"host_id": host_id,
"label": window.label,
"recurrence": window.recurrence.to_string(),
}),
None,
None,
)
.await;
tracing::info!(
window_id = %window.id,
%host_id,
recurrence = %window.recurrence,
user = %auth.username,
"Maintenance window created"
);
Ok(Json(json!(window)))
}
// ── PUT /api/v1/hosts/:host_id/maintenance-windows/:win_id ───────────────────
async fn update_window(
State(state): State<AppState>,
auth: AuthUser,
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
Json(req): Json<UpdateMaintenanceWindowRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
// Fetch existing record (verify ownership and existence).
let existing: Option<MaintenanceWindow> = sqlx::query_as(
r#"
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
FROM maintenance_windows
WHERE id = $1 AND host_id = $2
"#,
)
.bind(win_id)
.bind(host_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %win_id, "update_window: fetch failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
let existing = existing.ok_or_else(|| {
err(
StatusCode::NOT_FOUND,
"not_found",
"Maintenance window not found",
)
})?;
// Apply partial updates using existing values as defaults.
let new_label = req.label.unwrap_or(existing.label);
let new_recurrence = req.recurrence.unwrap_or(existing.recurrence);
let new_start_at = req.start_at.unwrap_or(existing.start_at);
let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes);
let new_rec_day = req.recurrence_day.or(existing.recurrence_day);
let new_enabled = req.enabled.unwrap_or(existing.enabled);
let new_auto_apply = req.auto_apply.unwrap_or(existing.auto_apply);
// Validate recurrence_day for the final recurrence type.
if new_recurrence == pm_core::models::WindowRecurrence::Weekly {
match new_rec_day {
Some(d) if (0..=6).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Weekly recurrence requires recurrence_day 0-6",
));
},
}
}
if new_recurrence == pm_core::models::WindowRecurrence::Monthly {
match new_rec_day {
Some(d) if (1..=31).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Monthly recurrence requires recurrence_day 1-31",
));
},
}
}
let updated: MaintenanceWindow = sqlx::query_as(
r#"
UPDATE maintenance_windows
SET label = $3,
recurrence = $4,
start_at = $5,
duration_minutes = $6,
recurrence_day = $7,
enabled = $8,
auto_apply = $9,
updated_at = NOW()
WHERE id = $1 AND host_id = $2
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
"#,
)
.bind(win_id)
.bind(host_id)
.bind(&new_label)
.bind(&new_recurrence)
.bind(new_start_at)
.bind(new_duration)
.bind(new_rec_day)
.bind(new_enabled)
.bind(new_auto_apply)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %win_id, "update_window: update failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
log_event(
&state.db,
AuditAction::MaintenanceWindowUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("maintenance_window"),
Some(&win_id.to_string()),
json!({ "host_id": host_id }),
None,
None,
)
.await;
tracing::info!(
window_id = %win_id,
%host_id,
user = %auth.username,
"Maintenance window updated"
);
Ok(Json(json!(updated)))
}
// ── DELETE /api/v1/hosts/:host_id/maintenance-windows/:win_id ────────────────
async fn delete_window(
State(state): State<AppState>,
auth: AuthUser,
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
let result = sqlx::query("DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2")
.bind(win_id)
.bind(host_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %win_id, "delete_window: delete failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if result.rows_affected() == 0 {
return Err(err(
StatusCode::NOT_FOUND,
"not_found",
"Maintenance window not found",
));
}
log_event(
&state.db,
AuditAction::MaintenanceWindowDeleted,
Some(auth.user_id),
Some(&auth.username),
Some("maintenance_window"),
Some(&win_id.to_string()),
json!({ "host_id": host_id }),
None,
None,
)
.await;
tracing::info!(
window_id = %win_id,
%host_id,
user = %auth.username,
"Maintenance window deleted"
);
Ok(Json(json!({ "message": "Maintenance window deleted" })))
}

17
crates/pm-web/src/routes/mod.rs Executable file
View File

@ -0,0 +1,17 @@
//! Route modules for the pm-web API.
pub mod auth;
pub mod ca;
pub mod discovery;
pub mod enrollment;
pub mod groups;
pub mod health_checks;
pub mod hosts;
pub mod jobs;
pub mod maintenance_windows;
pub mod settings;
pub mod sso;
pub mod status;
pub mod users;
pub mod ws;
pub mod reports;

View File

@ -0,0 +1,163 @@
//! Report generation endpoints.
//!
//! GET /api/v1/reports/compliance?format=csv|pdf&from=...&to=...&group_id=...
//! GET /api/v1/reports/patch-history?format=csv|pdf&from=...&to=...
//! GET /api/v1/reports/vulnerability?format=csv|pdf&from=...&to=...
//! GET /api/v1/reports/audit?format=csv|pdf&from=...&to=...
use axum::{
body::Bytes,
extract::{Query, State},
http::{header, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
routing::get,
Router,
};
use pm_reports::{ReportParams, ReportType};
use crate::AppState;
#[derive(serde::Deserialize)]
struct ReportQuery {
/// "csv" or "pdf" (defaults to "csv")
format: Option<String>,
from: Option<chrono::DateTime<chrono::Utc>>,
to: Option<chrono::DateTime<chrono::Utc>>,
group_id: Option<uuid::Uuid>,
}
pub fn router() -> Router<AppState> {
Router::new()
.route("/compliance", get(compliance_report))
.route("/patch-history", get(patch_history_report))
.route("/vulnerability", get(vulnerability_report))
.route("/audit", get(audit_report))
}
// ---------------------------------------------------------------------------
// Internal helper
// ---------------------------------------------------------------------------
async fn run_report(
db: sqlx::PgPool,
params: ReportParams,
use_pdf: bool,
csv_name: &'static str,
pdf_name: &'static str,
) -> Response {
let (ct, disposition, result) = if use_pdf {
let disp = format!("attachment; filename=\"{}\"", pdf_name);
let data = pm_reports::generate_pdf(&db, &params).await;
("application/pdf", disp, data)
} else {
let disp = format!("attachment; filename=\"{}\"", csv_name);
let data = pm_reports::generate_csv(&db, &params).await;
("text/csv; charset=utf-8", disp, data)
};
match result {
Ok(bytes) => {
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static(ct));
headers.insert(
header::CONTENT_DISPOSITION,
HeaderValue::from_str(&disposition)
.unwrap_or_else(|_| HeaderValue::from_static("attachment")),
);
(headers, Bytes::from(bytes)).into_response()
},
Err(e) => {
tracing::error!(error = %e, "report generation failed");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Report error: {}", e),
)
.into_response()
},
}
}
// ---------------------------------------------------------------------------
// Handlers
// ---------------------------------------------------------------------------
async fn compliance_report(
State(state): State<AppState>,
Query(q): Query<ReportQuery>,
) -> Response {
let params = ReportParams {
report_type: ReportType::Compliance,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"compliance-report.csv",
"compliance-report.pdf",
)
.await
}
async fn patch_history_report(
State(state): State<AppState>,
Query(q): Query<ReportQuery>,
) -> Response {
let params = ReportParams {
report_type: ReportType::PatchHistory,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"patch-history-report.csv",
"patch-history-report.pdf",
)
.await
}
async fn vulnerability_report(
State(state): State<AppState>,
Query(q): Query<ReportQuery>,
) -> Response {
let params = ReportParams {
report_type: ReportType::Vulnerability,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"vulnerability-report.csv",
"vulnerability-report.pdf",
)
.await
}
async fn audit_report(State(state): State<AppState>, Query(q): Query<ReportQuery>) -> Response {
let params = ReportParams {
report_type: ReportType::Audit,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"audit-report.csv",
"audit-report.pdf",
)
.await
}

View File

@ -0,0 +1,977 @@
//! Settings management routes.
//!
//! GET /api/v1/settings — get all settings (admin only)
//! PUT /api/v1/settings — update settings (admin only)
//! POST /api/v1/settings/sso/discover — discover OIDC endpoints (admin only)
//! POST /api/v1/settings/sso/test — test OIDC provider connectivity (admin only)
//! POST /api/v1/settings/azure-sso/test — backward-compat alias for SSO test (admin only)
//! POST /api/v1/settings/smtp/test — send test email (admin only)
//! GET /api/v1/settings/ip-whitelist — get IP whitelist (admin only)
//! PUT /api/v1/settings/ip-whitelist — update IP whitelist (admin only)
//! POST /api/v1/settings/audit-integrity — verify audit log integrity (admin only)
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use lettre::{
message::{header::ContentType, Mailbox},
transport::smtp::authentication::Credentials,
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
};
use pm_auth::rbac::AuthUser;
use pm_core::audit::{log_event, verify_integrity, AuditAction};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::AppState;
// ============================================================
// Data structures
// ============================================================
#[derive(Debug, Serialize)]
pub struct SettingsResponse {
pub oidc: OidcConfigResponse,
pub smtp: SmtpConfig,
pub polling: PollingConfig,
pub ip_whitelist: Vec<String>,
pub web_tls_strategy: String,
pub notification: NotificationConfig,
pub sso_callback_url: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OidcConfigResponse {
pub enabled: bool,
pub provider_type: String, // "keycloak", "azure", "custom"
pub display_name: String,
pub discovery_url: String,
pub client_id: String,
pub client_secret: String, // Always masked in responses
pub redirect_uri: String,
pub scopes: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SmtpConfig {
pub enabled: bool,
pub host: String,
pub port: u16,
pub username: String,
pub from: String,
pub tls_mode: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PollingConfig {
pub health_poll_interval_secs: u64,
pub patch_poll_interval_secs: u64,
}
#[derive(Debug, Deserialize)]
pub struct UpdateSettingsRequest {
pub oidc: Option<OidcConfigUpdate>,
pub smtp: Option<SmtpConfigUpdate>,
pub polling: Option<PollingConfigUpdate>,
pub ip_whitelist: Option<Vec<String>>,
pub web_tls_strategy: Option<String>,
pub notification: Option<NotificationConfigUpdate>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct NotificationConfig {
pub email_enabled: bool,
pub email_from: String,
pub recipients: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct NotificationConfigUpdate {
pub email_enabled: Option<bool>,
pub email_from: Option<String>,
pub recipients: Option<Vec<String>>,
}
#[derive(Debug, Deserialize)]
pub struct OidcConfigUpdate {
pub enabled: Option<bool>,
pub provider_type: Option<String>,
pub display_name: Option<String>,
pub discovery_url: Option<String>,
pub client_id: Option<String>,
pub client_secret: Option<String>,
pub redirect_uri: Option<String>,
pub scopes: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct OidcDiscoveryRequest {
pub discovery_url: String,
}
#[derive(Debug, Serialize)]
#[allow(dead_code)]
pub struct OidcDiscoveryResult {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub jwks_uri: String,
pub userinfo_endpoint: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct SmtpConfigUpdate {
pub enabled: Option<bool>,
pub host: Option<String>,
pub port: Option<u16>,
pub username: Option<String>,
pub password: Option<String>,
pub from: Option<String>,
pub tls_mode: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct PollingConfigUpdate {
pub health_poll_interval_secs: Option<u64>,
pub patch_poll_interval_secs: Option<u64>,
}
#[derive(Debug, Deserialize)]
pub struct IpWhitelistUpdate {
pub entries: Vec<String>,
}
// ============================================================
// Router
// ============================================================
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(get_settings).put(update_settings))
.route("/sso/discover", post(discover_oidc))
.route("/sso/test", post(test_oidc))
.route("/azure-sso/test", post(test_azure_sso_compat))
.route("/smtp/test", post(test_smtp))
.route(
"/ip-whitelist",
get(get_ip_whitelist).put(update_ip_whitelist),
)
.route("/audit-integrity", post(audit_integrity))
}
// ============================================================
// Helpers
// ============================================================
const MASKED: &str = "********";
fn write_access_required(auth: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
Ok(())
}
async fn load_system_config(
pool: &sqlx::PgPool,
) -> Result<HashMap<String, String>, (StatusCode, Json<Value>)> {
let rows: Vec<(String, String)> = sqlx::query_as("SELECT key, value FROM system_config")
.fetch_all(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load system_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(rows.into_iter().collect())
}
fn build_settings_response(
cfg: &HashMap<String, String>,
oidc: OidcConfigResponse,
) -> SettingsResponse {
let get = |key: &str| -> String { cfg.get(key).cloned().unwrap_or_default() };
let recipients: Vec<String> =
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
SettingsResponse {
oidc,
smtp: SmtpConfig {
enabled: get("smtp_enabled") == "true",
host: get("smtp_host"),
port: get("smtp_port").parse().unwrap_or(587),
username: get("smtp_username"),
from: get("smtp_from"),
tls_mode: get("smtp_tls_mode"),
},
polling: PollingConfig {
health_poll_interval_secs: get("health_poll_interval_secs").parse().unwrap_or(300),
patch_poll_interval_secs: get("patch_poll_interval_secs").parse().unwrap_or(1800),
},
ip_whitelist: serde_json::from_str(&get("ip_whitelist")).unwrap_or_default(),
web_tls_strategy: get("web_tls_strategy"),
notification: NotificationConfig {
email_enabled: get("notification_email_enabled") == "true",
email_from: get("notification_email_from"),
recipients,
},
sso_callback_url: get("sso_callback_url"),
}
}
async fn update_config_key(
pool: &sqlx::PgPool,
key: &str,
value: &str,
) -> Result<(), (StatusCode, Json<Value>)> {
sqlx::query("UPDATE system_config SET value = $1, updated_at = NOW() WHERE key = $2")
.bind(value)
.bind(key)
.execute(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, key, "Failed to update system_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(())
}
async fn fetch_oidc_config(
pool: &sqlx::PgPool,
) -> Result<OidcConfigResponse, (StatusCode, Json<Value>)> {
let row: Option<(bool, String, String, String, String, String, String, String)> = sqlx::query_as(
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
)
.fetch_optional(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(match row {
Some((
enabled,
provider_type,
display_name,
discovery_url,
client_id,
client_secret,
redirect_uri,
scopes,
)) => OidcConfigResponse {
enabled,
provider_type,
display_name,
discovery_url,
client_id,
client_secret: if client_secret.is_empty() {
String::new()
} else {
MASKED.to_string()
},
redirect_uri,
scopes,
},
None => OidcConfigResponse {
enabled: false,
provider_type: "azure".to_string(),
display_name: "Azure AD".to_string(),
discovery_url: String::new(),
client_id: String::new(),
client_secret: String::new(),
redirect_uri: String::new(),
scopes: "openid profile email".to_string(),
},
})
}
// ============================================================
// GET /api/v1/settings
// ============================================================
async fn get_settings(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let cfg = load_system_config(&state.db).await?;
// Inject read-only config values from TOML file (not stored in DB)
let mut cfg = cfg;
cfg.insert(
"sso_callback_url".to_string(),
state.config.security.sso_callback_url.clone(),
);
let oidc = fetch_oidc_config(&state.db).await?;
Ok(Json(build_settings_response(&cfg, oidc)))
}
// ============================================================
// PUT /api/v1/settings
// ============================================================
async fn update_settings(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<UpdateSettingsRequest>,
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
// Update OIDC config
if let Some(oidc) = req.oidc {
let update_secret = oidc
.client_secret
.as_ref()
.is_some_and(|s| s != MASKED && !s.is_empty());
let result = if update_secret {
sqlx::query(
"UPDATE oidc_config SET \
enabled = COALESCE($1, enabled), \
provider_type = COALESCE($2, provider_type), \
display_name = COALESCE($3, display_name), \
discovery_url = COALESCE($4, discovery_url), \
client_id = COALESCE($5, client_id), \
client_secret = $6, \
redirect_uri = COALESCE($7, redirect_uri), \
scopes = COALESCE($8, scopes), \
updated_at = NOW() \
WHERE id = 1",
)
.bind(oidc.enabled)
.bind(&oidc.provider_type)
.bind(&oidc.display_name)
.bind(&oidc.discovery_url)
.bind(&oidc.client_id)
.bind(oidc.client_secret.as_deref().unwrap_or(""))
.bind(&oidc.redirect_uri)
.bind(&oidc.scopes)
.execute(&state.db)
.await
} else {
sqlx::query(
"UPDATE oidc_config SET \
enabled = COALESCE($1, enabled), \
provider_type = COALESCE($2, provider_type), \
display_name = COALESCE($3, display_name), \
discovery_url = COALESCE($4, discovery_url), \
client_id = COALESCE($5, client_id), \
redirect_uri = COALESCE($6, redirect_uri), \
scopes = COALESCE($7, scopes), \
updated_at = NOW() \
WHERE id = 1",
)
.bind(oidc.enabled)
.bind(&oidc.provider_type)
.bind(&oidc.display_name)
.bind(&oidc.discovery_url)
.bind(&oidc.client_id)
.bind(&oidc.redirect_uri)
.bind(&oidc.scopes)
.execute(&state.db)
.await
};
result.map_err(|e| {
tracing::error!(error = %e, "Failed to update oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": format!("Failed to update OIDC config: {}", e) } })),
)
})?;
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("oidc"),
Some("1"),
json!({ "section": "oidc" }),
None,
None,
)
.await;
}
// Update SMTP config
if let Some(smtp) = &req.smtp {
if let Some(v) = smtp.enabled {
update_config_key(&state.db, "smtp_enabled", &v.to_string()).await?;
}
if let Some(ref v) = smtp.host {
update_config_key(&state.db, "smtp_host", v).await?;
}
if let Some(v) = smtp.port {
update_config_key(&state.db, "smtp_port", &v.to_string()).await?;
}
if let Some(ref v) = smtp.username {
update_config_key(&state.db, "smtp_username", v).await?;
}
if let Some(ref v) = smtp.password {
if v != MASKED {
update_config_key(&state.db, "smtp_password", v).await?;
}
}
if let Some(ref v) = smtp.from {
update_config_key(&state.db, "smtp_from", v).await?;
}
if let Some(ref v) = smtp.tls_mode {
update_config_key(&state.db, "smtp_tls_mode", v).await?;
}
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("smtp"),
Some("system_config"),
json!({ "section": "smtp" }),
None,
None,
)
.await;
}
// Update polling config
if let Some(polling) = &req.polling {
if let Some(v) = polling.health_poll_interval_secs {
update_config_key(&state.db, "health_poll_interval_secs", &v.to_string()).await?;
}
if let Some(v) = polling.patch_poll_interval_secs {
update_config_key(&state.db, "patch_poll_interval_secs", &v.to_string()).await?;
}
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("polling"),
Some("system_config"),
json!({ "section": "polling" }),
None,
None,
)
.await;
}
// Update IP whitelist
if let Some(ref entries) = req.ip_whitelist {
let json_str = serde_json::to_string(entries).unwrap_or_else(|_| "[]".to_string());
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
// Update in-memory AuthConfig for immediate enforcement
state.auth_config.update_ip_whitelist(entries.clone());
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("ip_whitelist"),
Some("system_config"),
json!({ "entries": entries }),
None,
None,
)
.await;
}
// Update web TLS strategy
if let Some(ref v) = req.web_tls_strategy {
update_config_key(&state.db, "web_tls_strategy", v).await?;
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("web_tls_strategy"),
Some("system_config"),
json!({ "web_tls_strategy": v }),
None,
None,
)
.await;
}
// Update notification config
if let Some(notif) = &req.notification {
if let Some(v) = notif.email_enabled {
update_config_key(&state.db, "notification_email_enabled", &v.to_string()).await?;
}
if let Some(ref v) = notif.email_from {
update_config_key(&state.db, "notification_email_from", v).await?;
}
if let Some(ref v) = notif.recipients {
let json_str = serde_json::to_string(v).unwrap_or_else(|_| "[]".to_string());
update_config_key(&state.db, "notification_email_recipients", &json_str).await?;
}
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("notification"),
Some("system_config"),
json!({ "section": "notification" }),
None,
None,
)
.await;
}
// Return updated settings
let cfg = load_system_config(&state.db).await?;
// Inject read-only config values from TOML file (not stored in DB)
let mut cfg = cfg;
cfg.insert(
"sso_callback_url".to_string(),
state.config.security.sso_callback_url.clone(),
);
let oidc = fetch_oidc_config(&state.db).await?;
Ok(Json(build_settings_response(&cfg, oidc)))
}
// ============================================================
// POST /api/v1/settings/sso/discover
// ============================================================
async fn discover_oidc(
State(_state): State<AppState>,
auth: AuthUser,
Json(req): Json<OidcDiscoveryRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
if req.discovery_url.is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "bad_request", "message": "discovery_url is required" } }),
),
));
}
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| {
tracing::error!(error = %e, "Failed to build HTTP client");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
)
})?;
match client.get(&req.discovery_url).send().await {
Ok(resp) if resp.status().is_success() => {
let body: Value = resp.json().await.unwrap_or(json!({}));
Ok(Json(json!({
"success": true,
"issuer": body.get("issuer").and_then(|v| v.as_str()).unwrap_or(""),
"authorization_endpoint": body.get("authorization_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
"token_endpoint": body.get("token_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
"jwks_uri": body.get("jwks_uri").and_then(|v| v.as_str()).unwrap_or(""),
"userinfo_endpoint": body.get("userinfo_endpoint").and_then(|v| v.as_str()),
})))
},
Ok(resp) => Err((
StatusCode::BAD_GATEWAY,
Json(
json!({ "error": { "code": "discovery_failed", "message": format!("Discovery endpoint returned HTTP {}", resp.status()) } }),
),
)),
Err(e) => Err((
StatusCode::BAD_GATEWAY,
Json(
json!({ "error": { "code": "discovery_failed", "message": format!("Failed to reach discovery endpoint: {}", e) } }),
),
)),
}
}
// ============================================================
// POST /api/v1/settings/sso/test
// ============================================================
async fn test_oidc(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let row: Option<(bool, String, String)> = sqlx::query_as(
"SELECT enabled, provider_type, discovery_url FROM oidc_config WHERE id = 1",
)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let (enabled, provider_type, discovery_url) = match row {
Some(r) => r,
None => {
return Ok(Json(json!({
"success": false,
"message": "OIDC is not configured"
})));
},
};
if !enabled {
return Ok(Json(json!({
"success": false,
"message": "OIDC is not enabled"
})));
}
if discovery_url.is_empty() {
return Ok(Json(json!({
"success": false,
"message": "OIDC discovery URL is not set"
})));
}
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| {
tracing::error!(error = %e, "Failed to build HTTP client");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
)
})?;
match client.get(&discovery_url).send().await {
Ok(resp) if resp.status().is_success() => {
let body: Value = resp.json().await.unwrap_or(json!({}));
let issuer = body.get("issuer").and_then(|v| v.as_str()).unwrap_or("");
let provider_label = match provider_type.as_str() {
"keycloak" => "Keycloak",
"azure" => "Azure AD",
_ => "OIDC",
};
Ok(Json(json!({
"success": true,
"message": format!("{} provider verified successfully", provider_label),
"issuer": issuer,
"provider_type": provider_type,
})))
},
Ok(resp) => Ok(Json(json!({
"success": false,
"message": format!("Failed to reach OIDC provider: HTTP {}", resp.status())
}))),
Err(e) => Ok(Json(json!({
"success": false,
"message": format!("Failed to reach OIDC provider: {}", e)
}))),
}
}
// ============================================================
// POST /api/v1/settings/azure-sso/test (backward-compatible alias)
// ============================================================
async fn test_azure_sso_compat(
state: State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
test_oidc(state, auth).await
}
// ============================================================
// POST /api/v1/settings/smtp/test
// ============================================================
async fn test_smtp(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let cfg = load_system_config(&state.db).await?;
let smtp_enabled = cfg.get("smtp_enabled").map(|v| v.as_str()) == Some("true");
if !smtp_enabled {
return Ok(Json(json!({
"success": false,
"message": "SMTP is not enabled"
})));
}
let host = cfg.get("smtp_host").cloned().unwrap_or_default();
let port: u16 = cfg
.get("smtp_port")
.and_then(|v| v.parse().ok())
.unwrap_or(587);
let username = cfg.get("smtp_username").cloned().unwrap_or_default();
let password = cfg.get("smtp_password").cloned().unwrap_or_default();
let from_addr = cfg.get("smtp_from").cloned().unwrap_or_default();
let tls_mode = cfg
.get("smtp_tls_mode")
.cloned()
.unwrap_or_else(|| "starttls".to_string());
let recipients_str = cfg
.get("notification_email_recipients")
.cloned()
.unwrap_or_default();
let recipients: Vec<String> = serde_json::from_str(&recipients_str).unwrap_or_default();
if host.is_empty() || from_addr.is_empty() {
return Ok(Json(json!({
"success": false,
"message": "SMTP host or from address is not configured"
})));
}
let result = send_smtp_test(
&host,
port,
&username,
&password,
&from_addr,
&tls_mode,
&recipients,
)
.await;
match result {
Ok(()) => {
let recipient_info = if recipients.is_empty() {
String::new()
} else {
format!(" and {} recipient(s)", recipients.len())
};
Ok(Json(json!({
"success": true,
"message": format!("Test email sent successfully to from address{}", recipient_info)
})))
},
Err(e) => Ok(Json(json!({
"success": false,
"message": format!("Failed to send test email: {}", e)
}))),
}
}
async fn send_smtp_test(
host: &str,
port: u16,
username: &str,
password: &str,
from_addr: &str,
tls_mode: &str,
recipients: &[String],
) -> Result<(), String> {
let from_mailbox: Mailbox = from_addr
.parse()
.map_err(|e| format!("Invalid from address: {}", e))?;
let mut builder = Message::builder()
.from(from_mailbox.clone())
.to(from_mailbox);
for recipient in recipients {
if let Ok(addr) = recipient.parse() {
builder = builder.bcc(addr);
}
}
let body = if recipients.is_empty() {
"This is a test email from Linux Patch Manager.".to_string()
} else {
format!(
"This is a test email from Linux Patch Manager.\n\nSent to: {}",
recipients.join(", ")
)
};
let email = builder
.subject("Linux Patch Manager — SMTP Test")
.header(ContentType::TEXT_PLAIN)
.body(body)
.map_err(|e| format!("Failed to build email: {}", e))?;
let result = match tls_mode {
"tls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(host)
.map_err(|e| format!("TLS relay error: {}", e))?;
builder = builder.port(port);
if !username.is_empty() {
builder = builder
.credentials(Credentials::new(username.to_string(), password.to_string()));
}
let transport = builder.build();
transport.send(email).await
},
"starttls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(host)
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
builder = builder.port(port);
if !username.is_empty() {
builder = builder
.credentials(Credentials::new(username.to_string(), password.to_string()));
}
let transport = builder.build();
transport.send(email).await
},
_ => {
// "none" — plaintext / no TLS
let mut builder =
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(host).port(port);
if !username.is_empty() {
builder = builder
.credentials(Credentials::new(username.to_string(), password.to_string()));
}
let transport = builder.build();
transport.send(email).await
},
};
result
.map(|_| ())
.map_err(|e| format!("SMTP send error: {}", e))
}
// ============================================================
// GET /api/v1/settings/ip-whitelist
// ============================================================
async fn get_ip_whitelist(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let value: Option<String> = sqlx::query_scalar(
"SELECT value FROM system_config WHERE key = 'ip_whitelist'",
)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load ip_whitelist");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let entries: Vec<String> = serde_json::from_str(&value.unwrap_or_default()).unwrap_or_default();
Ok(Json(json!({ "entries": entries })))
}
// ============================================================
// PUT /api/v1/settings/ip-whitelist
// ============================================================
async fn update_ip_whitelist(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<IpWhitelistUpdate>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
// Validate each entry
for entry in &req.entries {
if entry.parse::<ipnet::IpNet>().is_err() && entry.parse::<std::net::IpAddr>().is_err() {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "bad_request", "message": format!("Invalid CIDR or IP: {}", entry) } }),
),
));
}
}
let json_str = serde_json::to_string(&req.entries).unwrap_or_else(|_| "[]".to_string());
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
// Update in-memory AuthConfig for immediate enforcement
state.auth_config.update_ip_whitelist(req.entries.clone());
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("ip_whitelist"),
Some("system_config"),
json!({ "entries": req.entries }),
None,
None,
)
.await;
Ok(Json(json!({ "entries": req.entries })))
}
// ============================================================
// POST /api/v1/settings/audit-integrity
// ============================================================
/// Verify audit log hash chain integrity.
/// Returns whether the chain is intact, rows checked, and any errors.
async fn audit_integrity(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let result = verify_integrity(&state.db).await;
log_event(
&state.db,
AuditAction::AuditIntegrityVerified,
Some(auth.user_id),
Some(&auth.username),
Some("audit_log"),
None,
json!({
"intact": result.intact,
"rows_checked": result.rows_checked,
"error_count": result.errors.len(),
}),
None,
None,
)
.await;
Ok(Json(json!({
"intact": result.intact,
"rows_checked": result.rows_checked,
"errors": result.errors.iter().map(|e| json!({
"row_id": e.row_id,
"expected_hash": e.expected_hash,
"actual_hash": e.actual_hash,
})).collect::<Vec<_>>(),
})))
}

838
crates/pm-web/src/routes/sso.rs Executable file
View File

@ -0,0 +1,838 @@
//! Generic OIDC SSO routes (Keycloak, Azure AD, Custom).
//!
//! Public routes (no auth required):
//! GET /api/v1/auth/sso/login — redirect to OIDC provider authorization URL
//! GET /api/v1/auth/sso/callback — handle OIDC provider callback, redirect to frontend SPA
//!
//! Backward-compatible aliases:
//! GET /api/v1/auth/azure/login → redirects to generic SSO login
//! GET /api/v1/auth/azure/callback → redirects to generic SSO callback
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Redirect},
routing::get,
Router,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use chrono::Utc;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use pm_auth::{jwt::issue_access_token, refresh};
use pm_core::audit::{log_event, AuditAction};
use serde::Deserialize;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::AppState;
// ============================================================
// Data structures
// ============================================================
#[derive(Clone)]
pub struct SsoSession {
pub code_verifier: String,
pub created_at: chrono::DateTime<Utc>,
}
#[derive(Debug, Deserialize)]
struct TokenResponse {
#[allow(dead_code)]
access_token: Option<String>,
id_token: Option<String>,
#[allow(dead_code)]
token_type: Option<String>,
#[allow(dead_code)]
expires_in: Option<i64>,
}
#[derive(Debug, Deserialize)]
struct IdTokenClaims {
email: Option<String>,
name: Option<String>,
sub: Option<String>,
oid: Option<String>,
preferred_username: Option<String>,
}
#[derive(Debug, sqlx::FromRow)]
struct DbUserForSso {
id: Uuid,
username: String,
display_name: String,
role: String,
is_active: bool,
mfa_enabled: bool,
}
/// OIDC provider configuration from database.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct OidcConfig {
pub enabled: bool,
pub provider_type: String,
pub display_name: String,
pub discovery_url: String,
pub client_id: String,
pub client_secret: String,
pub redirect_uri: String,
pub scopes: String,
}
/// Cached OIDC discovery document.
#[derive(Debug, Clone)]
pub struct OidcDiscovery {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub jwks_uri: String,
pub userinfo_endpoint: Option<String>,
pub fetched_at: chrono::DateTime<Utc>,
}
/// Cache for OIDC discovery documents and JWKS with TTL-based refresh.
#[derive(Default)]
pub struct OidcCache {
pub discovery: Option<OidcDiscovery>,
pub jwks: Option<serde_json::Value>,
pub jwks_fetched_at: Option<chrono::DateTime<Utc>>,
}
/// JWKS cache TTL in seconds (1 hour).
const JWKS_CACHE_TTL_SECS: i64 = 3600;
/// Discovery cache TTL in seconds (1 hour).
const DISCOVERY_CACHE_TTL_SECS: i64 = 3600;
// ============================================================
// Router
// ============================================================
pub fn public_router() -> Router<AppState> {
Router::new()
.route("/login", get(sso_login))
.route("/callback", get(sso_callback))
.route("/config", get(sso_config))
}
/// Backward-compatible Azure SSO routes — redirect to generic SSO endpoints.
pub fn azure_compat_router() -> Router<AppState> {
Router::new()
.route("/login", get(azure_login_redirect))
.route("/callback", get(azure_callback_redirect))
}
// ============================================================
// GET /api/v1/auth/sso/config
// ============================================================
/// Public endpoint returning minimal SSO configuration for the login page.
/// Returns only: enabled, display_name, auth_url — no secrets exposed.
async fn sso_config(
State(state): State<AppState>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let config = match load_oidc_config(&state.db).await {
Ok(c) => c,
Err(_) => {
// If we can't load config, SSO is effectively disabled
return Ok(Json(json!({
"enabled": false,
"display_name": "SSO",
"auth_url": ""
})));
},
};
Ok(Json(json!({
"enabled": config.enabled,
"display_name": if config.display_name.is_empty() { "SSO".to_string() } else { config.display_name },
"auth_url": "/api/v1/auth/sso/login"
})))
}
// ============================================================
// GET /api/v1/auth/sso/login
// ============================================================
async fn sso_login(
State(state): State<AppState>,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
let config = load_oidc_config(&state.db).await?;
if !config.enabled {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "SSO is not enabled" } })),
));
}
if config.discovery_url.is_empty() {
return Err((
StatusCode::FORBIDDEN,
Json(
json!({ "error": { "code": "forbidden", "message": "SSO discovery URL is not configured" } }),
),
));
}
// Fetch OIDC discovery document (with caching)
let discovery = match fetch_discovery(&state).await {
Ok(d) => d,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(
json!({ "error": { "code": "internal_error", "message": format!("Failed to fetch OIDC discovery: {}", e) } }),
),
));
},
};
// Generate PKCE code_verifier (32 random bytes → base64url)
let mut verifier_bytes = [0u8; 32];
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut verifier_bytes);
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
// code_challenge = BASE64URL(SHA256(code_verifier))
let challenge_digest = Sha256::digest(code_verifier.as_bytes());
let code_challenge = URL_SAFE_NO_PAD.encode(challenge_digest);
// Generate state token
let state_token = Uuid::new_v4().to_string();
// Store (state_token, code_verifier) in sso_sessions DashMap
state.sso_sessions.insert(
state_token.clone(),
SsoSession {
code_verifier,
created_at: Utc::now(),
},
);
// Build authorization URL from discovery
let encoded_scopes = urlencoding::encode(&config.scopes);
let auth_url = format!(
"{}?client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
discovery.authorization_endpoint,
urlencoding::encode(&config.client_id),
urlencoding::encode(&config.redirect_uri),
encoded_scopes,
code_challenge,
state_token
);
Ok(Redirect::to(&auth_url))
}
// ============================================================
// GET /api/v1/auth/sso/callback
// ============================================================
#[derive(Debug, Deserialize)]
struct CallbackParams {
code: Option<String>,
state: Option<String>,
error: Option<String>,
error_description: Option<String>,
}
async fn sso_callback(
State(state): State<AppState>,
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
) -> Result<Redirect, Redirect> {
let callback_url = &state.config.security.sso_callback_url;
let error_redirect = |code: &str, message: &str| -> Redirect {
let url = format!(
"{}?error={}&error_description={}",
callback_url,
urlencoding::encode(code),
urlencoding::encode(message)
);
Redirect::to(&url)
};
if let Some(error) = params.error {
let desc = params.error_description.unwrap_or_default();
let message = format!("OIDC provider error: {} - {}", error, desc);
return Err(error_redirect("sso_error", &message));
}
let code = match params.code {
Some(c) => c,
None => return Err(error_redirect("bad_request", "Missing authorization code")),
};
let state_token = match params.state {
Some(s) => s,
None => return Err(error_redirect("bad_request", "Missing state parameter")),
};
let sso_session = match state.sso_sessions.remove(&state_token).map(|(_, v)| v) {
Some(s) => s,
None => {
return Err(error_redirect(
"bad_request",
"Invalid or expired state token",
))
},
};
let config = match load_oidc_config(&state.db).await {
Ok(c) => c,
Err(_) => {
return Err(error_redirect(
"internal_error",
"Failed to load OIDC config",
))
},
};
let discovery = match fetch_discovery(&state).await {
Ok(d) => d,
Err(e) => {
tracing::error!(error = %e, "Failed to fetch OIDC discovery");
return Err(error_redirect(
"internal_error",
"Failed to fetch OIDC discovery",
));
},
};
// Exchange code for tokens
let client = match reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
{
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Failed to build HTTP client");
return Err(error_redirect("internal_error", "HTTP client error"));
},
};
let mut params_vec: Vec<(&str, String)> = vec![
("grant_type", "authorization_code".to_string()),
("code", code.clone()),
("redirect_uri", config.redirect_uri.clone()),
("client_id", config.client_id.clone()),
("code_verifier", sso_session.code_verifier.clone()),
];
// For confidential clients (Azure AD), include client_secret
if !config.client_secret.is_empty() {
params_vec.push(("client_secret", config.client_secret.clone()));
}
let token_resp = match client
.post(&discovery.token_endpoint)
.form(&params_vec)
.send()
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Token exchange request failed");
return Err(error_redirect(
"sso_error",
&format!("Token exchange failed: {}", e),
));
},
};
if !token_resp.status().is_success() {
let status = token_resp.status();
let body = token_resp.text().await.unwrap_or_default();
tracing::error!(status = %status, body = %body, "Token exchange failed");
return Err(error_redirect(
"sso_error",
&format!("Token exchange failed: HTTP {}", status),
));
}
let token_data: TokenResponse = match token_resp.json().await {
Ok(d) => d,
Err(e) => {
tracing::error!(error = %e, "Failed to parse token response");
return Err(error_redirect(
"internal_error",
"Failed to parse token response",
));
},
};
let id_token = match token_data.id_token {
Some(t) => t,
None => return Err(error_redirect("sso_error", "No id_token in response")),
};
let claims = match verify_id_token(&id_token, &config, &discovery, &state.oidc_cache).await {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Failed to verify id_token");
return Err(error_redirect(
"internal_error",
"Failed to verify id_token",
));
},
};
let email = claims.email.unwrap_or_default();
let name = claims.name.unwrap_or_default();
let oidc_sub = claims.sub.unwrap_or_default();
let azure_oid = claims.oid.unwrap_or_default();
let preferred_username = claims.preferred_username.unwrap_or_else(|| email.clone());
let provider_subject = if !oidc_sub.is_empty() {
oidc_sub.clone()
} else if !azure_oid.is_empty() {
azure_oid.clone()
} else {
return Err(error_redirect(
"sso_error",
"Missing subject identifier in id_token",
));
};
if email.is_empty() {
return Err(error_redirect("sso_error", "Missing email in id_token"));
}
let auth_provider = match config.provider_type.as_str() {
"keycloak" => "keycloak",
"azure" => "azure_sso",
_ => "oidc",
};
// First try exact match: email AND auth_provider
let user_opt: Option<DbUserForSso> = match sqlx::query_as(
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
FROM users WHERE email = $1 AND auth_provider = $2::auth_provider"#,
)
.bind(&email)
.bind(auth_provider)
.fetch_optional(&state.db)
.await
{
Ok(o) => o,
Err(e) => {
tracing::error!(error = %e, "Failed to look up SSO user");
return Err(error_redirect("internal_error", "Database error"));
},
};
let user = match user_opt {
Some(u) if !u.is_active => {
return Err(error_redirect("account_disabled", "Account is disabled"));
},
Some(u) => u,
None => {
// Try to find existing user by email alone (may have different auth_provider)
let existing_user: Option<DbUserForSso> = match sqlx::query_as(
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
FROM users WHERE email = $1"#,
)
.bind(&email)
.fetch_optional(&state.db)
.await
{
Ok(o) => o,
Err(e) => {
tracing::error!(error = %e, "Failed to look up existing user by email");
return Err(error_redirect("internal_error", "Database error"));
},
};
match existing_user {
Some(existing) if !existing.is_active => {
return Err(error_redirect("account_disabled", "Account is disabled"));
},
Some(existing) => {
// Link existing local user to SSO provider
tracing::info!(user_id = %existing.id, "Linking existing user to SSO provider");
if let Err(e) = sqlx::query(
"UPDATE users SET auth_provider = $1::auth_provider, azure_oid = COALESCE(azure_oid, $2), oidc_sub = COALESCE(oidc_sub, $3) WHERE id = $4",
)
.bind(auth_provider)
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
.bind(existing.id)
.execute(&state.db)
.await
{
tracing::error!(error = %e, "Failed to link user to SSO provider");
return Err(error_redirect("internal_error", "Failed to link SSO account"));
}
log_event(
&state.db,
AuditAction::UserCreated,
None,
Some(auth_provider),
Some("user"),
Some(&existing.id.to_string()),
json!({ "action": "sso_link", "auth_provider": auth_provider, "email": email }),
None,
None,
)
.await;
DbUserForSso {
id: existing.id,
username: existing.username.clone(),
display_name: if name.is_empty() {
existing.display_name.clone()
} else {
name
},
role: existing.role.clone(),
is_active: existing.is_active,
mfa_enabled: existing.mfa_enabled,
}
},
None => {
// No existing user - create new one
let id: Uuid = match sqlx::query_scalar(
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid, oidc_sub)
VALUES ($1, $2, $3, 'reporter'::user_role, $4::auth_provider, $5, $6)
RETURNING id"#,
)
.bind(&preferred_username)
.bind(&name)
.bind(&email)
.bind(auth_provider)
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
.fetch_one(&state.db)
.await
{
Ok(id) => id,
Err(e) => {
tracing::error!(error = %e, "Failed to create SSO user");
return Err(error_redirect("internal_error", "Failed to create user"));
},
};
log_event(
&state.db,
AuditAction::UserCreated,
None,
Some(auth_provider),
Some("user"),
Some(&id.to_string()),
json!({ "auth_provider": auth_provider, "email": email }),
None,
None,
)
.await;
DbUserForSso {
id,
username: preferred_username,
display_name: name,
role: "reporter".to_string(),
is_active: true,
mfa_enabled: false,
}
},
}
},
};
// Update last_login_at and provider subject IDs
if let Err(e) = sqlx::query(
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1), oidc_sub = COALESCE(oidc_sub, $2) WHERE id = $3",
)
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
.bind(user.id)
.execute(&state.db)
.await
{
tracing::error!(error = %e, "Failed to update last_login_at");
return Err(error_redirect("internal_error", "Database error"));
}
let access_ttl = state.config.security.jwt_access_ttl_secs as i64;
let access_token = match issue_access_token(
user.id,
&user.username,
&user.role,
access_ttl,
&state.signing_key_pem,
) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "Failed to issue access token");
return Err(error_redirect("internal_error", "Token issuance failed"));
},
};
let raw_refresh = match refresh::issue(&state.db, user.id, None, None).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Failed to issue refresh token");
return Err(error_redirect(
"internal_error",
"Refresh token issuance failed",
));
},
};
log_event(
&state.db,
AuditAction::UserLogin,
Some(user.id),
Some(&user.username),
None,
None,
json!({ "auth_provider": auth_provider }),
None,
None,
)
.await;
let user_json = json!({
"id": user.id.to_string(),
"username": user.username,
"display_name": user.display_name,
"role": user.role,
"auth_provider": auth_provider,
"mfa_enabled": user.mfa_enabled,
});
let redirect_url = format!(
"{}?access_token={}&refresh_token={}&token_type=Bearer&expires_in={}&user={}",
callback_url,
urlencoding::encode(&access_token),
urlencoding::encode(&raw_refresh.0),
access_ttl,
urlencoding::encode(&user_json.to_string()),
);
Ok(Redirect::to(&redirect_url))
}
// ============================================================
// Backward-compatible Azure SSO redirect handlers
// ============================================================
async fn azure_login_redirect(
State(state): State<AppState>,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
sso_login(State(state)).await
}
async fn azure_callback_redirect(
State(state): State<AppState>,
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
) -> Result<Redirect, Redirect> {
sso_callback(State(state), axum::extract::Query(params)).await
}
// ============================================================
// Database helpers
// ============================================================
async fn load_oidc_config(pool: &sqlx::PgPool) -> Result<OidcConfig, (StatusCode, Json<Value>)> {
let row: Option<OidcConfig> = sqlx::query_as(
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
)
.fetch_optional(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(row.unwrap_or(OidcConfig {
enabled: false,
provider_type: "azure".to_string(),
display_name: "Azure AD".to_string(),
discovery_url: String::new(),
client_id: String::new(),
client_secret: String::new(),
redirect_uri: String::new(),
scopes: "openid profile email".to_string(),
}))
}
// ============================================================
// OIDC Discovery & JWKS
// ============================================================
async fn fetch_discovery(state: &AppState) -> Result<OidcDiscovery, String> {
let config = match load_oidc_config(&state.db).await {
Ok(c) => c,
Err(_) => {
return Err("Failed to load OIDC config".to_string());
},
};
let discovery_url = config.discovery_url;
// Check cache first
{
let cache = state.oidc_cache.lock().await;
if let Some(ref disc) = cache.discovery {
let elapsed = Utc::now().signed_duration_since(disc.fetched_at);
if elapsed.num_seconds() < DISCOVERY_CACHE_TTL_SECS {
return Ok(disc.clone());
}
}
}
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;
let resp = client
.get(&discovery_url)
.send()
.await
.map_err(|e| format!("Discovery fetch failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!(
"Discovery fetch failed: HTTP {}{}",
status, body
));
}
let doc: Value = resp
.json()
.await
.map_err(|e| format!("Failed to parse discovery document: {}", e))?;
let discovery = OidcDiscovery {
issuer: doc
.get("issuer")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
authorization_endpoint: doc
.get("authorization_endpoint")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
token_endpoint: doc
.get("token_endpoint")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
jwks_uri: doc
.get("jwks_uri")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
userinfo_endpoint: doc
.get("userinfo_endpoint")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
fetched_at: Utc::now(),
};
{
let mut cache = state.oidc_cache.lock().await;
cache.discovery = Some(discovery.clone());
}
Ok(discovery)
}
async fn verify_id_token(
token: &str,
config: &OidcConfig,
discovery: &OidcDiscovery,
oidc_cache: &Arc<Mutex<OidcCache>>,
) -> Result<IdTokenClaims, String> {
let header = decode_header(token).map_err(|e| format!("Failed to decode JWT header: {}", e))?;
let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
let jwks = {
let cache = oidc_cache.lock().await;
let needs_fetch = match (&cache.jwks, &cache.jwks_fetched_at) {
(None, _) => true,
(Some(_), None) => true,
(Some(_), Some(fetched)) => {
let elapsed = Utc::now().signed_duration_since(*fetched);
elapsed.num_seconds() > JWKS_CACHE_TTL_SECS
},
};
if needs_fetch {
drop(cache);
let jwks_value = fetch_jwks(&discovery.jwks_uri).await?;
let mut cache = oidc_cache.lock().await;
cache.jwks = Some(jwks_value);
cache.jwks_fetched_at = Some(Utc::now());
cache.jwks.clone().unwrap()
} else {
cache.jwks.clone().unwrap()
}
};
let keys_array = jwks
.get("keys")
.ok_or("JWKS response missing 'keys' array")?
.as_array()
.ok_or("JWKS 'keys' is not an array")?;
let jwk = keys_array
.iter()
.find(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()))
.ok_or_else(|| format!("No matching JWK found for kid: {}", kid))?;
let n = jwk
.get("n")
.and_then(|v| v.as_str())
.ok_or("JWK missing 'n' (modulus) field")?;
let e = jwk
.get("e")
.and_then(|v| v.as_str())
.ok_or("JWK missing 'e' (exponent) field")?;
let decoding_key = DecodingKey::from_rsa_components(n, e)
.map_err(|e| format!("Failed to construct RSA decoding key: {}", e))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.iss = Some(HashSet::from([discovery.issuer.clone()]));
validation.aud = Some(HashSet::from([config.client_id.clone()]));
validation.leeway = 60;
let token_data = decode::<IdTokenClaims>(token, &decoding_key, &validation)
.map_err(|e| format!("JWT signature verification failed: {}", e))?;
Ok(token_data.claims)
}
async fn fetch_jwks(jwks_uri: &str) -> Result<serde_json::Value, String> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| format!("Failed to build HTTP client for JWKS fetch: {}", e))?;
let resp = client
.get(jwks_uri)
.send()
.await
.map_err(|e| format!("JWKS fetch request failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("JWKS fetch failed: HTTP {}{}", status, body));
}
resp.json::<serde_json::Value>()
.await
.map_err(|e| format!("Failed to parse JWKS response: {}", e))
}

View File

@ -0,0 +1,145 @@
//! Fleet status routes.
//!
//! GET /api/v1/status/fleet — aggregate health and patch summary across all hosts.
use axum::{extract::State, http::StatusCode, response::Json, routing::get, Router};
use serde::Serialize;
use serde_json::{json, Value};
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new().route("/fleet", get(fleet_status))
}
// ── Response type ─────────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
pub struct FleetStatus {
pub total_hosts: i64,
pub healthy: i64,
pub degraded: i64,
pub unreachable: i64,
pub pending: i64,
pub total_pending_patches: i64,
pub hosts_requiring_reboot: i64,
pub compliance_pct: f64,
}
// ── GET /api/v1/status/fleet ──────────────────────────────────────────────────
pub async fn fleet_status(
State(state): State<AppState>,
) -> Result<Json<FleetStatus>, (StatusCode, Json<Value>)> {
// ── 1. Host health aggregates ─────────────────────────────────────────
let health_row: (i64, i64, i64, i64, i64) = sqlx::query_as(
r#"
SELECT
COUNT(*) AS total_hosts,
COUNT(*) FILTER (WHERE health_status = 'healthy') AS healthy,
COUNT(*) FILTER (WHERE health_status = 'degraded') AS degraded,
COUNT(*) FILTER (WHERE health_status = 'unreachable') AS unreachable,
COUNT(*) FILTER (WHERE health_status = 'pending') AS pending
FROM hosts
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query host health aggregates");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let (total_hosts, healthy, degraded, unreachable, pending) = health_row;
// ── 2. Total pending patches across fleet (latest row per host) ───────
let total_pending_patches: i64 = sqlx::query_scalar(
r#"
SELECT COALESCE(SUM(patch_count), 0)
FROM (
SELECT DISTINCT ON (host_id) patch_count
FROM host_patch_data
ORDER BY host_id, polled_at DESC
) latest
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query total pending patches");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
// ── 3. Hosts requiring a reboot (latest patch row per host) ───────────
let hosts_requiring_reboot: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM (
SELECT DISTINCT ON (host_id) available_patches
FROM host_patch_data
ORDER BY host_id, polled_at DESC
) latest
WHERE available_patches @> '[{"requires_reboot": true}]'
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query reboot-required hosts");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
// ── 4. Compliance: hosts with zero pending patches / total hosts ───────
// Hosts that have been polled and have patch_count == 0 are considered
// compliant. Hosts with no patch data at all are excluded from the
// compliance calculation.
let compliant_hosts: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM (
SELECT DISTINCT ON (host_id) patch_count
FROM host_patch_data
ORDER BY host_id, polled_at DESC
) latest
WHERE patch_count = 0
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query compliant hosts");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let compliance_pct = if total_hosts == 0 {
100.0_f64
} else {
(compliant_hosts as f64 / total_hosts as f64) * 100.0
};
// Round to one decimal place.
let compliance_pct = (compliance_pct * 10.0).round() / 10.0;
Ok(Json(FleetStatus {
total_hosts,
healthy,
degraded,
unreachable,
pending,
total_pending_patches,
hosts_requiring_reboot,
compliance_pct,
}))
}

571
crates/pm-web/src/routes/users.rs Executable file
View File

@ -0,0 +1,571 @@
//! User management routes.
//!
//! GET /api/v1/users — list users (admin only)
//! POST /api/v1/users — create user (admin only)
//! GET /api/v1/users/:id — get user detail
//! PUT /api/v1/users/:id — update user
//! DELETE /api/v1/users/:id — delete user (admin only)
//! GET /api/v1/users/me — current user profile
//! POST /api/v1/users/:id/revoke — revoke all sessions (admin only)
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{delete, get, post, put},
Router,
};
use pm_auth::validate_password_strength;
use pm_auth::{hash_password, rbac::AuthUser, session::force_logout, verify_password};
use pm_core::{
audit::{log_event, AuditAction},
models::{
AdminResetPasswordRequest, ChangePasswordRequest, CreateUserRequest, UpdateUserRequest,
User,
},
};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_users).post(create_user))
.route("/me", get(get_current_user))
.route("/me/password", put(change_own_password))
.route("/{id}", get(get_user).put(update_user).delete(delete_user))
.route("/{id}/password", put(admin_reset_password))
.route("/{id}/mfa", delete(admin_disable_mfa))
.route("/{id}/revoke", post(revoke_user_sessions))
}
async fn list_users(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Vec<User>>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
sqlx::query_as::<_, User>(
r#"SELECT id, username, display_name, email, role, auth_provider,
mfa_enabled, is_active, force_password_reset, last_login_at,
created_at, updated_at
FROM users ORDER BY username"#,
)
.fetch_all(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})
}
async fn create_user(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateUserRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
// Validate password strength
if let Err(msg) = validate_password_strength(&req.password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
let hash = hash_password(&req.password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
let role = match req.role.to_lowercase().as_str() {
"admin" => "admin",
"reporter" => "reporter",
_ => "operator",
};
let id: Uuid = sqlx::query_scalar(
r#"INSERT INTO users (username, display_name, email, role, auth_provider, password_hash)
VALUES ($1, $2, $3, $4::user_role, 'local', $5)
RETURNING id"#,
)
.bind(&req.username)
.bind(req.display_name.as_deref().unwrap_or(&req.username))
.bind(&req.email)
.bind(role)
.bind(&hash)
.fetch_one(&state.db)
.await
.map_err(|e| {
let msg = if e.to_string().contains("unique") {
"Username or email already exists".to_string()
} else {
"Database error".to_string()
};
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
log_event(
&state.db,
AuditAction::UserCreated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({ "username": req.username }),
None,
None,
)
.await;
Ok(Json(json!({ "id": id, "message": "User created" })))
}
async fn get_current_user(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
fetch_user(&state.db, auth.user_id).await
}
async fn get_user(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
// Users can see themselves; admin can see anyone
if !auth.role.is_admin() && auth.user_id != id {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
fetch_user(&state.db, id).await
}
async fn fetch_user(
pool: &sqlx::PgPool,
id: Uuid,
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
let user: Option<User> = sqlx::query_as(
r#"SELECT id, username, display_name, email, role, auth_provider,
mfa_enabled, is_active, force_password_reset, last_login_at,
created_at, updated_at
FROM users WHERE id = $1"#,
)
.bind(id)
.fetch_optional(pool)
.await
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
user.map(Json).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
)
})
}
async fn update_user(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<UpdateUserRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() && auth.user_id != id {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
// Only admins can change role or active status
if (req.role.is_some() || req.is_active.is_some() || req.force_password_reset.is_some())
&& !auth.role.is_admin()
{
return Err((
StatusCode::FORBIDDEN,
Json(
json!({ "error": { "code": "forbidden", "message": "Admin role required to change role, status, or force_password_reset" } }),
),
));
}
let role_str = req
.role
.as_deref()
.map(|r| match r.to_lowercase().as_str() {
"admin" => "admin",
"reporter" => "reporter",
_ => "operator",
});
let rows = sqlx::query(
r#"UPDATE users SET
display_name = COALESCE($1, display_name),
email = COALESCE($2, email),
role = COALESCE($3::user_role, role),
is_active = COALESCE($4, is_active),
force_password_reset = COALESCE($5, force_password_reset),
updated_at = NOW()
WHERE id = $6"#,
)
.bind(req.display_name.as_deref())
.bind(req.email.as_deref())
.bind(role_str)
.bind(req.is_active)
.bind(req.force_password_reset)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({}),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User updated" })))
}
async fn delete_user(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
if auth.user_id == id {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "bad_request", "message": "Cannot delete your own account" } }),
),
));
}
let rows = sqlx::query("DELETE FROM users WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
log_event(
&state.db,
AuditAction::UserDeleted,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({}),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User deleted" })))
}
async fn revoke_user_sessions(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
let count = force_logout(&state.db, id).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
Ok(Json(
json!({ "message": "Sessions revoked", "count": count }),
))
}
// ============================================================
// PUT /api/v1/users/me/password — change own password
// ============================================================
async fn change_own_password(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<ChangePasswordRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Fetch current password hash
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
.bind(auth.user_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to fetch password hash");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?
.flatten();
let hash_str = hash.unwrap_or_default();
let valid = verify_password(&req.current_password, &hash_str).unwrap_or(false);
if !valid {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
),
));
}
// Validate new password strength
if let Err(msg) = validate_password_strength(&req.new_password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
let new_hash = hash_password(&req.new_password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
sqlx::query(
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, updated_at = NOW() WHERE id = $2",
)
.bind(&new_hash)
.bind(auth.user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to update password");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
)
})?;
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&auth.user_id.to_string()),
json!({ "action": "password_change" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Password changed successfully" })))
}
// ============================================================
// PUT /api/v1/users/:id/password — admin reset password
// ============================================================
async fn admin_reset_password(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<AdminResetPasswordRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
// Verify target user exists
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)")
.bind(id)
.fetch_one(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
if !exists {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
// Validate new password strength
if let Err(msg) = validate_password_strength(&req.new_password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
let new_hash = hash_password(&req.new_password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
sqlx::query(
"UPDATE users SET password_hash = $1, force_password_reset = $2, updated_at = NOW() WHERE id = $3",
)
.bind(&new_hash)
.bind(req.force_password_reset)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to reset password");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to reset password" } })),
)
})?;
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({ "action": "admin_password_reset", "force_password_reset": req.force_password_reset }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Password reset successfully" })))
}
// ============================================================
// DELETE /api/v1/users/:id/mfa — admin disable MFA
// ============================================================
async fn admin_disable_mfa(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
let rows = sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE, updated_at = NOW() WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to disable MFA");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({ "action": "admin_mfa_disabled" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "MFA disabled successfully" })))
}

205
crates/pm-web/src/routes/ws.rs Executable file
View File

@ -0,0 +1,205 @@
//! WebSocket relay routes — M7
//!
//! POST /api/v1/ws/ticket — create a single-use WS auth ticket (JWT-protected)
//! GET /api/v1/ws/jobs — browser WebSocket endpoint (ticket-authenticated)
use axum::{
extract::ws::{Message, WebSocket},
extract::{Query, State, WebSocketUpgrade},
http::StatusCode,
response::{Json, Response},
routing::{get, post},
Router,
};
use chrono::{Duration, Utc};
use pm_auth::rbac::AuthUser;
use serde::Deserialize;
use serde_json::{json, Value};
use sqlx::postgres::PgListener;
use ulid::Ulid;
use uuid::Uuid;
use crate::AppState;
// ── WsTicket ──────────────────────────────────────────────────────────────────
/// Single-use WebSocket authentication ticket stored in-memory.
#[derive(Debug, Clone)]
pub struct WsTicket {
pub user_id: Uuid,
pub role: String,
pub expires_at: chrono::DateTime<Utc>,
}
// ── Router ────────────────────────────────────────────────────────────────────
/// Router for ticket-issuance endpoint (JWT-protected, merged into protected_api).
pub fn ticket_router() -> Router<AppState> {
Router::new().route("/ws/ticket", post(create_ticket_handler))
}
/// Router for the WebSocket endpoint (ticket-authenticated, NO JWT middleware).
pub fn ws_router() -> Router<AppState> {
Router::new().route("/api/v1/ws/jobs", get(ws_handler))
}
// ── Error helper ─────────────────────────────────────────────────────────────
#[inline]
fn err(
status: StatusCode,
code: &'static str,
message: impl Into<String>,
) -> (StatusCode, Json<Value>) {
(
status,
Json(json!({ "error": { "code": code, "message": message.into() } })),
)
}
// ── POST /api/v1/ws/ticket ────────────────────────────────────────────────────
/// Issue a single-use WebSocket authentication ticket (60 s expiry).
pub async fn create_ticket_handler(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let ticket_id = Ulid::new().to_string();
let expires_at = Utc::now() + Duration::seconds(60);
let ticket = WsTicket {
user_id: auth.user_id,
role: auth.role.as_str().to_string(),
expires_at,
};
state.ws_tickets.insert(ticket_id.clone(), ticket);
tracing::info!(
user_id = %auth.user_id,
username = %auth.username,
ticket = %ticket_id,
"WS ticket issued"
);
Ok(Json(json!({ "ticket": ticket_id })))
}
// ── GET /api/v1/ws/jobs ───────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
pub struct WsQuery {
pub ticket: String,
}
/// Browser WebSocket upgrade endpoint — authenticates via single-use ticket.
pub async fn ws_handler(
State(state): State<AppState>,
Query(q): Query<WsQuery>,
ws: WebSocketUpgrade,
) -> Result<Response, (StatusCode, Json<Value>)> {
// Validate and consume the ticket atomically.
let ticket = {
let entry = state.ws_tickets.get(&q.ticket);
match entry {
None => {
return Err(err(
StatusCode::UNAUTHORIZED,
"invalid_ticket",
"WebSocket ticket not found or already used",
));
},
Some(t) => {
if t.expires_at < Utc::now() {
drop(t);
state.ws_tickets.remove(&q.ticket);
return Err(err(
StatusCode::UNAUTHORIZED,
"ticket_expired",
"WebSocket ticket has expired",
));
}
t.clone()
},
}
};
// Single-use: remove immediately after validation.
state.ws_tickets.remove(&q.ticket);
tracing::info!(
user_id = %ticket.user_id,
role = %ticket.role,
"Browser WebSocket connection upgraded"
);
let db = state.db.clone();
Ok(ws.on_upgrade(move |socket| handle_browser_ws(socket, db, ticket)))
}
// ── WebSocket handler ─────────────────────────────────────────────────────────
/// Drive the browser WebSocket: LISTEN on `job_update` and forward payloads.
async fn handle_browser_ws(mut socket: WebSocket, db: sqlx::PgPool, ticket: WsTicket) {
// Acquire a dedicated PG listener connection.
let mut listener = match PgListener::connect_with(&db).await {
Ok(l) => l,
Err(e) => {
tracing::error!(error = %e, user_id = %ticket.user_id, "Failed to create PgListener");
let _ = socket
.send(Message::Text(
json!({ "error": "internal_error" }).to_string().into(),
))
.await;
return;
},
};
if let Err(e) = listener.listen("job_update").await {
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener LISTEN failed");
return;
}
tracing::info!(user_id = %ticket.user_id, "Browser WS: LISTEN job_update started");
loop {
tokio::select! {
// Forward PG notifications to the browser.
notify_result = listener.recv() => {
match notify_result {
Ok(notification) => {
let payload = notification.payload().to_string();
tracing::debug!(user_id = %ticket.user_id, payload = %payload, "Forwarding job_update");
if socket.send(Message::Text(payload.into())).await.is_err() {
tracing::info!(user_id = %ticket.user_id, "Browser WS send failed — client disconnected");
break;
}
}
Err(e) => {
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener recv error");
break;
}
}
}
// Handle incoming frames from the browser (ping/close).
msg = socket.recv() => {
match msg {
Some(Ok(Message::Close(_))) | None => {
tracing::info!(user_id = %ticket.user_id, "Browser WS closed by client");
break;
}
Some(Ok(Message::Ping(data))) if socket.send(Message::Pong(data.clone())).await.is_err() => {
break;
}
Some(Err(e)) => {
tracing::debug!(error = %e, user_id = %ticket.user_id, "Browser WS recv error");
break;
}
_ => {}
}
}
}
}
tracing::info!(user_id = %ticket.user_id, "Browser WS handler exiting");
}