feat: add bump-version.sh script for version management
Automates version bumps across all version source files: - Cargo.toml (PRIMARY - workspace.package.version) - debian/changelog (prepend new entry) - debian/control (update Version field) - scripts/build-package.sh (update VERSION variable) - frontend/package.json (update version field) - Stale references check after bump Usage: ./scripts/bump-version.sh <new_version> <old_version>
This commit is contained in:
46
crates/pm-web/Cargo.toml
Normal file
46
crates/pm-web/Cargo.toml
Normal file
@ -0,0 +1,46 @@
|
||||
[package]
|
||||
name = "pm-web"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "pm-web"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
pm-ca = { path = "../pm-ca" }
|
||||
pm-core = { path = "../pm-core" }
|
||||
pm-auth = { path = "../pm-auth" }
|
||||
pm-reports = { path = "../pm-reports" }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
axum-server = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
ulid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
dashmap = { version = "6" }
|
||||
tower_governor = { workspace = true }
|
||||
governor = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
|
||||
rand = { workspace = true }
|
||||
hex = "0.4"
|
||||
base64 = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
url = { workspace = true }
|
||||
urlencoding = "2"
|
||||
353
crates/pm-web/src/main.rs
Normal file
353
crates/pm-web/src/main.rs
Normal file
@ -0,0 +1,353 @@
|
||||
//! pm-web — Linux Patch Manager web server.
|
||||
|
||||
mod routes;
|
||||
|
||||
use axum::{extract::State, http::StatusCode, middleware, response::Json, routing::get, Router};
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use dashmap::DashMap;
|
||||
use pm_auth::{
|
||||
jwt,
|
||||
rbac::{require_auth, AuthConfig},
|
||||
};
|
||||
use pm_core::{
|
||||
config::AppConfig, db, logging, models::PkiBundle, request_id::request_id_middleware,
|
||||
};
|
||||
use routes::sso::{OidcCache, SsoSession};
|
||||
use routes::ws::WsTicket;
|
||||
use serde_json::{json, Value};
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_governor::{
|
||||
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
|
||||
};
|
||||
use tower_http::{
|
||||
services::{ServeDir, ServeFile},
|
||||
trace::TraceLayer,
|
||||
};
|
||||
|
||||
/// Shared application state threaded through Axum.
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub db: sqlx::PgPool,
|
||||
pub config: Arc<AppConfig>,
|
||||
pub signing_key_pem: String,
|
||||
pub auth_config: Arc<AuthConfig>,
|
||||
/// In-memory store for single-use WebSocket authentication tickets.
|
||||
pub ws_tickets: Arc<DashMap<String, WsTicket>>,
|
||||
/// In-memory store for SSO PKCE sessions (state → code_verifier).
|
||||
pub sso_sessions: Arc<DashMap<String, SsoSession>>,
|
||||
/// Cached OIDC discovery document and JWKS for SSO id_token verification.
|
||||
pub oidc_cache: Arc<Mutex<OidcCache>>,
|
||||
/// Internal certificate authority for mTLS client cert issuance.
|
||||
pub ca: Arc<pm_ca::CertAuthority>,
|
||||
/// Short-lived cache for approved enrollment PKI bundles.
|
||||
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Install the default crypto provider for rustls (required since 0.23)
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install rustls crypto provider");
|
||||
|
||||
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
|
||||
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
|
||||
|
||||
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
|
||||
eprintln!("Config file not found or invalid, using defaults");
|
||||
AppConfig::default()
|
||||
});
|
||||
|
||||
logging::init(&config.logging);
|
||||
tracing::info!(
|
||||
version = env!("CARGO_PKG_VERSION"),
|
||||
"patch-manager-web starting"
|
||||
);
|
||||
|
||||
let signing_key_pem = jwt::load_signing_key(&config.security.jwt_signing_key_path)
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "JWT signing key not found (dev mode)");
|
||||
String::new()
|
||||
});
|
||||
|
||||
let verify_key_pem =
|
||||
jwt::load_verify_key(&config.security.jwt_verify_key_path).unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "JWT verify key not found (dev mode)");
|
||||
String::new()
|
||||
});
|
||||
|
||||
let auth_config = Arc::new(AuthConfig::new(
|
||||
verify_key_pem,
|
||||
&config.security.ip_whitelist,
|
||||
));
|
||||
|
||||
let pool = db::init_pool(&config.database).await?;
|
||||
db::run_migrations(&pool).await?;
|
||||
|
||||
// Initialise the internal CA using the configured certificate paths.
|
||||
// The CA certificate and key must exist at the configured locations and be
|
||||
// unencrypted PEM. If absent, a new CA is generated in that directory.
|
||||
let ca_base = std::path::Path::new(&config.security.ca_cert_path)
|
||||
.parent()
|
||||
.expect("CA certificate path must have a parent directory");
|
||||
let ca = pm_ca::CertAuthority::init(ca_base, &pool)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "CA init failed (dev mode)");
|
||||
panic!("CA initialization failed: {}", e);
|
||||
});
|
||||
|
||||
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
|
||||
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
|
||||
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
|
||||
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
|
||||
|
||||
// Background task: purge expired WS tickets every 30 seconds.
|
||||
{
|
||||
let tickets = ws_tickets.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(30));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = chrono::Utc::now();
|
||||
let before = tickets.len();
|
||||
tickets.retain(|_, v| v.expires_at > now);
|
||||
let removed = before.saturating_sub(tickets.len());
|
||||
if removed > 0 {
|
||||
tracing::debug!(removed, "Purged expired WS tickets");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Background task: purge expired SSO sessions every 60 seconds (sessions older than 10 minutes).
|
||||
{
|
||||
let sessions = sso_sessions.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = chrono::Utc::now();
|
||||
let cutoff = now - chrono::Duration::minutes(10);
|
||||
let before = sessions.len();
|
||||
sessions.retain(|_, v| v.created_at > cutoff);
|
||||
let removed = before.saturating_sub(sessions.len());
|
||||
if removed > 0 {
|
||||
tracing::debug!(removed, "Purged expired SSO sessions");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Background task: purge approved enrollment PKI bundles every 10 minutes.
|
||||
{
|
||||
let approved = approved_enrollments.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(600));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
approved.clear();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let state = AppState {
|
||||
db: pool,
|
||||
config: Arc::new(config.clone()),
|
||||
signing_key_pem,
|
||||
auth_config,
|
||||
ws_tickets,
|
||||
sso_sessions,
|
||||
ca: Arc::new(ca),
|
||||
approved_enrollments,
|
||||
oidc_cache,
|
||||
};
|
||||
|
||||
let app = build_router(state);
|
||||
|
||||
let addr: SocketAddr = format!("{}:{}", config.server.host, config.server.port)
|
||||
.parse()
|
||||
.expect("Invalid bind address");
|
||||
|
||||
// Try to load TLS certificate and key; fall back to plain HTTP if missing.
|
||||
let tls_cert = std::path::Path::new(&config.security.web_tls_cert_path);
|
||||
let tls_key = std::path::Path::new(&config.security.web_tls_key_path);
|
||||
|
||||
if tls_cert.exists() && tls_key.exists() {
|
||||
let tls_config = RustlsConfig::from_pem_file(
|
||||
&config.security.web_tls_cert_path,
|
||||
&config.security.web_tls_key_path,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load TLS certificates");
|
||||
e
|
||||
})?;
|
||||
|
||||
tracing::info!(%addr, "Listening (HTTPS)");
|
||||
axum_server::bind_rustls(addr, tls_config)
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await?;
|
||||
} else {
|
||||
tracing::warn!(
|
||||
cert_path = %config.security.web_tls_cert_path,
|
||||
key_path = %config.security.web_tls_key_path,
|
||||
"TLS certificates not found — falling back to plain HTTP. \
|
||||
This is insecure for production!"
|
||||
);
|
||||
tracing::info!(%addr, "Listening (HTTP — no TLS)");
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Construct the full Axum router.
|
||||
pub fn build_router(state: AppState) -> Router {
|
||||
let static_dir = state.config.server.static_dir.clone();
|
||||
let auth_config = state.auth_config.clone();
|
||||
let rl = &state.config.rate_limit;
|
||||
|
||||
// Enrollment rate limiting: strict (5 req/min per IP, burst 3)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 12_000ms = ~5/min sustained
|
||||
let enrollment_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(12_000)
|
||||
.burst_size(rl.enrollment_burst)
|
||||
.finish()
|
||||
.expect("Invalid enrollment governor config"),
|
||||
);
|
||||
|
||||
// Auth rate limiting: moderate (20 req/min per IP, burst 10)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 3_000ms = ~20/min sustained
|
||||
let auth_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(3_000)
|
||||
.burst_size(rl.auth_burst)
|
||||
.finish()
|
||||
.expect("Invalid auth governor config"),
|
||||
);
|
||||
|
||||
// API rate limiting: normal (120 req/min per IP, burst 30)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 500ms = ~120/min sustained
|
||||
let api_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(500)
|
||||
.burst_size(rl.api_burst)
|
||||
.finish()
|
||||
.expect("Invalid API governor config"),
|
||||
);
|
||||
|
||||
// Enrollment routes with strict per-IP rate limiting
|
||||
let enrollment_router =
|
||||
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
|
||||
|
||||
// Public auth routes with moderate per-IP rate limiting
|
||||
let auth_public_router =
|
||||
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
|
||||
|
||||
// SSO routes with moderate per-IP rate limiting
|
||||
let sso_public_router =
|
||||
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
|
||||
let sso_azure_router =
|
||||
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
|
||||
|
||||
// All protected API routes — require valid JWT, with normal per-IP rate limiting
|
||||
let protected_api = Router::new()
|
||||
// Auth: MFA setup/verify
|
||||
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
|
||||
.nest("/auth", routes::auth::protected_router())
|
||||
// Hosts
|
||||
.nest("/hosts", routes::hosts::router())
|
||||
// Host-scoped certificate endpoints (merged separately to avoid conflict)
|
||||
.nest("/hosts", routes::ca::host_cert_router())
|
||||
// Groups
|
||||
.nest("/groups", routes::groups::router())
|
||||
// Users
|
||||
.nest("/users", routes::users::router())
|
||||
// Discovery
|
||||
.nest("/discovery", routes::discovery::router())
|
||||
// Fleet status
|
||||
.nest("/status", routes::status::router())
|
||||
// Patch jobs
|
||||
.nest("/jobs", routes::jobs::router())
|
||||
// Maintenance windows (nested under hosts path param)
|
||||
.nest(
|
||||
"/hosts/{host_id}/maintenance-windows",
|
||||
routes::maintenance_windows::router(),
|
||||
)
|
||||
// Maintenance windows — bulk list-all endpoint
|
||||
.nest(
|
||||
"/maintenance-windows",
|
||||
routes::maintenance_windows::all_windows_router(),
|
||||
)
|
||||
// CA root certificate download
|
||||
.nest("/ca", routes::ca::ca_router())
|
||||
// Certificate list / renew / revoke
|
||||
.nest("/certificates", routes::ca::certs_router())
|
||||
// WS ticket issuance (JWT-protected — ticket returned to browser, then used for WS upgrade)
|
||||
.merge(routes::ws::ticket_router())
|
||||
// Reports
|
||||
.nest("/reports", routes::reports::router())
|
||||
.nest(
|
||||
"/hosts/{host_id}/health-checks",
|
||||
routes::health_checks::router(),
|
||||
)
|
||||
// Settings (admin-only)
|
||||
.nest("/settings", routes::settings::router())
|
||||
// Admin enrollment routes (JWT protected, Admin role enforced)
|
||||
.nest("/admin", routes::enrollment::admin_router())
|
||||
// Apply rate limiting then auth middleware
|
||||
.layer(GovernorLayer::new(api_governor))
|
||||
.route_layer(middleware::from_fn(move |req, next| {
|
||||
let auth_config = auth_config.clone();
|
||||
require_auth(auth_config, req, next)
|
||||
}));
|
||||
|
||||
Router::new()
|
||||
.route("/status/health", get(health_handler))
|
||||
// Public auth routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth", auth_public_router)
|
||||
// Public enrollment endpoints (rate-limited, no JWT)
|
||||
.nest("/api/v1", enrollment_router)
|
||||
// Public SSO routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth/sso", sso_public_router)
|
||||
// Public Azure SSO routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth/azure", sso_azure_router)
|
||||
// Protected API routes (JWT required, rate-limited)
|
||||
.nest("/api/v1", protected_api)
|
||||
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
|
||||
.merge(routes::ws::ws_router())
|
||||
// Serve React SPA
|
||||
.fallback_service(
|
||||
ServeDir::new(&static_dir)
|
||||
.append_index_html_on_directories(true)
|
||||
.fallback(ServeFile::new(format!("{}/index.html", static_dir))),
|
||||
)
|
||||
.layer(middleware::from_fn(request_id_middleware))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn health_handler(State(state): State<AppState>) -> Result<Json<Value>, StatusCode> {
|
||||
let db_ok = sqlx::query("SELECT 1").execute(&state.db).await.is_ok();
|
||||
let status = if db_ok { "healthy" } else { "degraded" };
|
||||
let body = json!({ "service": "patch-manager-web", "version": env!("CARGO_PKG_VERSION"), "status": status, "database": if db_ok { "ok" } else { "error" } });
|
||||
if db_ok {
|
||||
Ok(Json(body))
|
||||
} else {
|
||||
Err(StatusCode::SERVICE_UNAVAILABLE)
|
||||
}
|
||||
}
|
||||
434
crates/pm-web/src/routes/auth.rs
Executable file
434
crates/pm-web/src/routes/auth.rs
Executable file
@ -0,0 +1,434 @@
|
||||
//! Authentication route handlers.
|
||||
//!
|
||||
//! Public routes (no auth required):
|
||||
//! POST /api/v1/auth/login
|
||||
//! POST /api/v1/auth/refresh
|
||||
//! POST /api/v1/auth/logout
|
||||
//!
|
||||
//! Protected routes (JWT required):
|
||||
//! GET /api/v1/auth/mfa/setup
|
||||
//! POST /api/v1/auth/mfa/verify
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::Json,
|
||||
routing::delete,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::{
|
||||
hash_password, mfa_totp,
|
||||
rbac::AuthUser,
|
||||
session::{self, LoginRequest, LoginResponse},
|
||||
validate_password_strength, verify_password,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ============================================================
|
||||
// Public router — no authentication required
|
||||
// ============================================================
|
||||
|
||||
pub fn public_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", post(login_handler))
|
||||
.route("/refresh", post(refresh_handler))
|
||||
.route("/logout", post(logout_handler))
|
||||
.route(
|
||||
"/force-change-password",
|
||||
post(force_change_password_handler),
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Protected router — requires valid JWT (applied by caller)
|
||||
// ============================================================
|
||||
|
||||
pub fn protected_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/mfa/setup", get(mfa_setup_handler))
|
||||
.route("/mfa/verify", post(mfa_verify_handler))
|
||||
.route("/mfa", delete(disable_mfa))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn user_agent(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("user-agent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
fn remote_ip(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.split(',').next().unwrap_or("").trim().to_string())
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/login
|
||||
// ============================================================
|
||||
|
||||
async fn login_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
|
||||
let ip = remote_ip(&headers);
|
||||
let ua = user_agent(&headers);
|
||||
|
||||
session::login(
|
||||
&state.db,
|
||||
&req,
|
||||
&state.signing_key_pem,
|
||||
state.config.security.jwt_access_ttl_secs as i64,
|
||||
ua.as_deref(),
|
||||
ip.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
use pm_auth::session::SessionError;
|
||||
let (status, code, message) = match e {
|
||||
SessionError::InvalidCredentials | SessionError::InvalidMfaCode => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_credentials",
|
||||
"Invalid username or password",
|
||||
),
|
||||
SessionError::MfaRequired => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"mfa_required",
|
||||
"MFA code required",
|
||||
),
|
||||
SessionError::AccountDisabled => (
|
||||
StatusCode::FORBIDDEN,
|
||||
"account_disabled",
|
||||
"Account is disabled",
|
||||
),
|
||||
SessionError::PasswordResetRequired => (
|
||||
StatusCode::FORBIDDEN,
|
||||
"password_reset_required",
|
||||
"Password reset is required before login",
|
||||
),
|
||||
SessionError::AccountLocked => (
|
||||
StatusCode::LOCKED,
|
||||
"account_locked",
|
||||
"Account is locked due to too many failed login attempts",
|
||||
),
|
||||
_ => {
|
||||
tracing::error!(error = %e, "Login error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"An error occurred",
|
||||
)
|
||||
},
|
||||
};
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/refresh
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RefreshRequest {
|
||||
refresh_token: String,
|
||||
}
|
||||
|
||||
async fn refresh_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<RefreshRequest>,
|
||||
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
|
||||
let ip = remote_ip(&headers);
|
||||
let ua = user_agent(&headers);
|
||||
|
||||
session::refresh_session(
|
||||
&state.db,
|
||||
&req.refresh_token,
|
||||
&state.signing_key_pem,
|
||||
state.config.security.jwt_access_ttl_secs as i64,
|
||||
ua.as_deref(),
|
||||
ip.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
use pm_auth::session::SessionError;
|
||||
let (status, code, msg) = match e {
|
||||
SessionError::Refresh(_) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_refresh_token",
|
||||
"Refresh token is invalid or expired",
|
||||
),
|
||||
SessionError::AccountDisabled => (
|
||||
StatusCode::FORBIDDEN,
|
||||
"account_disabled",
|
||||
"Account is disabled",
|
||||
),
|
||||
_ => {
|
||||
tracing::error!(error = %e, "Refresh error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"An error occurred",
|
||||
)
|
||||
},
|
||||
};
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": msg } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/logout
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LogoutRequest {
|
||||
refresh_token: String,
|
||||
}
|
||||
|
||||
async fn logout_handler(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<LogoutRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
session::logout(&state.db, &req.refresh_token)
|
||||
.await
|
||||
.map(|_| Json(json!({ "message": "Logged out successfully" })))
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Logout error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "An error occurred" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/mfa/setup (JWT required — via middleware)
|
||||
// ============================================================
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/force-change-password (PUBLIC — no JWT)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ForceChangePasswordRequest {
|
||||
username: String,
|
||||
current_password: String,
|
||||
new_password: String,
|
||||
}
|
||||
|
||||
async fn force_change_password_handler(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<ForceChangePasswordRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Validate new password strength
|
||||
if let Err(msg) = validate_password_strength(&req.new_password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Look up user by username
|
||||
let row: Option<(Uuid, Option<String>, bool)> = sqlx::query_as(
|
||||
"SELECT id, password_hash, force_password_reset FROM users WHERE username = $1 AND auth_provider = 'local'",
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to fetch user");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (user_id, hash_opt, _force_reset) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
|
||||
),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Verify current password
|
||||
let hash_str = hash_opt.as_deref().unwrap_or("");
|
||||
let valid = verify_password(&req.current_password, hash_str).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Hash and update password, clear force_password_reset
|
||||
let new_hash = hash_password(&req.new_password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, failed_login_attempts = 0, locked_until = NULL, updated_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(&new_hash)
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update password");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(user_id = %user_id, username = %req.username, "Password changed via force-change-password");
|
||||
|
||||
Ok(Json(json!({ "message": "Password changed successfully" })))
|
||||
}
|
||||
|
||||
async fn mfa_setup_handler(
|
||||
auth_user: AuthUser,
|
||||
) -> Result<Json<mfa_totp::TotpSetup>, (StatusCode, Json<Value>)> {
|
||||
mfa_totp::generate_setup(&auth_user.username)
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "TOTP setup error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/mfa/verify (JWT required — via middleware)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MfaVerifyRequest {
|
||||
secret_base32: String,
|
||||
code: String,
|
||||
}
|
||||
|
||||
async fn mfa_verify_handler(
|
||||
State(state): State<AppState>,
|
||||
auth_user: AuthUser,
|
||||
Json(req): Json<MfaVerifyRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let valid =
|
||||
mfa_totp::verify_code(&auth_user.username, &req.secret_base32, &req.code).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "invalid_code", "message": "Invalid TOTP code" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("UPDATE users SET totp_secret = $1, mfa_enabled = TRUE WHERE id = $2")
|
||||
.bind(&req.secret_base32)
|
||||
.bind(auth_user.user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to save TOTP secret");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to enable MFA" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(user_id = %auth_user.user_id, "MFA enabled for user");
|
||||
Ok(Json(json!({ "message": "MFA enabled successfully" })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// DELETE /api/v1/auth/mfa (JWT required — disable own MFA)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DisableMfaRequest {
|
||||
password: String,
|
||||
}
|
||||
|
||||
async fn disable_mfa(
|
||||
State(state): State<AppState>,
|
||||
auth_user: AuthUser,
|
||||
Json(req): Json<DisableMfaRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Verify current password to confirm identity
|
||||
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
|
||||
.bind(auth_user.user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to fetch password hash");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?
|
||||
.flatten();
|
||||
|
||||
let hash_str = hash.unwrap_or_default();
|
||||
let valid = verify_password(&req.password, &hash_str).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE WHERE id = $1")
|
||||
.bind(auth_user.user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to disable MFA");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(user_id = %auth_user.user_id, "MFA disabled for user");
|
||||
Ok(Json(json!({ "message": "MFA disabled successfully" })))
|
||||
}
|
||||
516
crates/pm-web/src/routes/ca.rs
Executable file
516
crates/pm-web/src/routes/ca.rs
Executable file
@ -0,0 +1,516 @@
|
||||
//! CA / certificate management routes.
|
||||
//!
|
||||
//! ca_router() → mounted at /api/v1/ca
|
||||
//! GET /root.crt download_root_ca (any authed role)
|
||||
//!
|
||||
//! certs_router() → mounted at /api/v1/certificates
|
||||
//! GET / list_certificates (any authed role)
|
||||
//! POST /:cert_id/renew renew_cert (admin only)
|
||||
//! DELETE /:cert_id revoke_cert (admin only)
|
||||
//!
|
||||
//! host_cert_router() → merged under /api/v1/hosts
|
||||
//! GET /:host_id/client.crt download_client_cert (admin only)
|
||||
//! POST /:host_id/certificates issue_client_cert (admin only)
|
||||
//! POST /:host_id/certificates/reissue reissue_host_cert (admin only)
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, Query, State},
|
||||
http::{header, Response, StatusCode},
|
||||
response::Json,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::audit::{log_event, AuditAction};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::Row;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router constructors ───────────────────────────────────────────────────────
|
||||
|
||||
/// Handles routes mounted at /api/v1/ca
|
||||
pub fn ca_router() -> Router<AppState> {
|
||||
Router::new().route("/root.crt", get(download_root_ca))
|
||||
}
|
||||
|
||||
/// Handles routes mounted at /api/v1/certificates
|
||||
pub fn certs_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_certificates))
|
||||
.route("/{cert_id}/renew", post(renew_cert))
|
||||
.route("/{cert_id}", delete(revoke_cert))
|
||||
}
|
||||
|
||||
/// Handles cert-specific paths merged under /api/v1/hosts.
|
||||
/// Only adds paths not already claimed by the hosts router.
|
||||
pub fn host_cert_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/{host_id}/client.crt", get(download_client_cert))
|
||||
.route("/{host_id}/certificates", post(issue_client_cert))
|
||||
.route("/{host_id}/certificates/reissue", post(reissue_host_cert))
|
||||
}
|
||||
|
||||
// ── Shared types ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Row returned from the `certificates` table.
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
struct CertRow {
|
||||
id: Uuid,
|
||||
host_id: Option<Uuid>,
|
||||
serial_number: String,
|
||||
common_name: String,
|
||||
/// Cast to TEXT in all queries to avoid custom-enum decode.
|
||||
status: String,
|
||||
issued_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
revoked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Query params for `list_certificates`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CertListQuery {
|
||||
host_id: Option<Uuid>,
|
||||
status: Option<String>,
|
||||
}
|
||||
|
||||
/// Request body for `issue_client_cert`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IssueCertRequest {
|
||||
hostname: String,
|
||||
}
|
||||
|
||||
// ── Helper: build PEM download response ──────────────────────────────────────
|
||||
|
||||
fn pem_response(pem: String, filename: &str) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
|
||||
let disposition = format!("attachment; filename=\"{filename}\"");
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/x-pem-file")
|
||||
.header(header::CONTENT_DISPOSITION, disposition)
|
||||
.body(Body::from(pem))
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build PEM response");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Response build error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── Helper: admin-only guard ──────────────────────────────────────────────────
|
||||
|
||||
fn require_write_access(user: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
|
||||
if !user.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Helper: map sqlx error to 500 ─────────────────────────────────────────────
|
||||
|
||||
fn db_error(e: sqlx::Error) -> (StatusCode, Json<Value>) {
|
||||
tracing::error!(error = %e, "Database error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── Helper: build the full IssuedCert JSON response ──────────────────────────
|
||||
|
||||
fn issued_cert_json(issued: &pm_ca::IssuedCert) -> Value {
|
||||
json!({
|
||||
"cert_pem": issued.cert_pem,
|
||||
"key_pem": issued.key_pem,
|
||||
"serial_number": issued.serial_number,
|
||||
"expires_at": issued.expires_at,
|
||||
"server_cert_pem": issued.server_cert_pem,
|
||||
"server_key_pem": issued.server_key_pem,
|
||||
"server_serial_number": issued.server_serial_number,
|
||||
"ca_root_pem": issued.ca_root_pem,
|
||||
})
|
||||
}
|
||||
|
||||
// ── GET /api/v1/ca/root.crt ───────────────────────────────────────────────────
|
||||
|
||||
/// Download the root CA certificate as a PEM file.
|
||||
async fn download_root_ca(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
|
||||
let pem = state.ca.root_cert_pem().to_owned();
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateDownloaded,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some("root_ca"),
|
||||
json!({ "operation": "download_root_ca" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
pem_response(pem, "ca.crt")
|
||||
}
|
||||
|
||||
// ── GET /api/v1/certificates ──────────────────────────────────────────────────
|
||||
|
||||
/// List certificates with optional `?host_id=` and `?status=` filters.
|
||||
async fn list_certificates(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Query(q): Query<CertListQuery>,
|
||||
) -> Result<Json<Vec<CertRow>>, (StatusCode, Json<Value>)> {
|
||||
// Use the non-macro query_as form — avoids needing DATABASE_URL at compile
|
||||
// time. status is cast to TEXT so sqlx decodes it into String directly.
|
||||
let rows: Vec<CertRow> = match (q.host_id, q.status.as_deref()) {
|
||||
(Some(hid), Some(st)) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
WHERE host_id = $1 AND status::text = $2
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.bind(hid)
|
||||
.bind(st)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
(Some(hid), None) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
WHERE host_id = $1
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.bind(hid)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
(None, Some(st)) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
WHERE status::text = $1
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.bind(st)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
(None, None) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
}
|
||||
.map_err(db_error)?;
|
||||
|
||||
Ok(Json(rows))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:host_id/client.crt ────────────────────────────────────
|
||||
|
||||
/// Download the most recent active client certificate PEM for a host.
|
||||
async fn download_client_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
let cert_pem: Option<String> = sqlx::query_scalar(
|
||||
r#"SELECT cert_pem
|
||||
FROM certificates
|
||||
WHERE host_id = $1
|
||||
AND status = 'active'::cert_status
|
||||
AND common_name NOT LIKE '%-server'
|
||||
ORDER BY issued_at DESC
|
||||
LIMIT 1"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to fetch client cert");
|
||||
db_error(e)
|
||||
})?;
|
||||
|
||||
match cert_pem {
|
||||
Some(pem) => {
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateDownloaded,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "operation": "download_client_cert" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
pem_response(pem, "client.crt")
|
||||
},
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({
|
||||
"error": {
|
||||
"code": "not_found",
|
||||
"message": "No active certificate found for this host"
|
||||
}
|
||||
})),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/certificates ─────────────────────────────────
|
||||
|
||||
/// Issue a new mTLS client certificate (and server certificate) for a host.
|
||||
/// **The private keys are returned only once — the caller must save them.**
|
||||
async fn issue_client_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
Json(req): Json<IssueCertRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
// Look up the host's IP address from the database.
|
||||
let ip_address: String = sqlx::query_scalar("SELECT host(ip_address) FROM hosts WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to fetch host IP address");
|
||||
if e.to_string().contains("no rows") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
} else {
|
||||
db_error(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let issued = state
|
||||
.ca
|
||||
.issue_client_cert(host_id, &req.hostname, &ip_address, &state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, hostname = %req.hostname,
|
||||
"Failed to issue client cert");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateIssued,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "hostname": req.hostname, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(issued_cert_json(&issued)))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/certificates/:cert_id/renew ─────────────────────────────────
|
||||
|
||||
/// Revoke the specified certificate and issue a replacement with the same CN.
|
||||
async fn renew_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(cert_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
let issued = state.ca.renew_cert(cert_id, &state.db).await.map_err(|e| {
|
||||
let msg = e.to_string();
|
||||
tracing::error!(error = %e, %cert_id, "Failed to renew cert");
|
||||
if msg.contains("not found") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(
|
||||
json!({ "error": { "code": "not_found", "message": "Certificate not found" } }),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
|
||||
)
|
||||
}
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateRenewed,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&cert_id.to_string()),
|
||||
json!({ "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(issued_cert_json(&issued)))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/certificates/reissue ────────────────────────
|
||||
|
||||
/// Revoke ALL active certificates for a host and issue new ones.
|
||||
/// The private keys are returned only once — the caller must save them.
|
||||
async fn reissue_host_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
// Look up the host's FQDN and IP address for the new certificate CN and SANs.
|
||||
let row = sqlx::query("SELECT fqdn, host(ip_address) AS ip_address FROM hosts WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to fetch host FQDN/IP");
|
||||
if e.to_string().contains("no rows") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
} else {
|
||||
db_error(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let fqdn: String = row.try_get("fqdn").map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to read fqdn");
|
||||
db_error(e)
|
||||
})?;
|
||||
let ip_address: String = row.try_get("ip_address").map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to read ip_address");
|
||||
db_error(e)
|
||||
})?;
|
||||
|
||||
// Revoke all active certificates for this host.
|
||||
let revoked = sqlx::query(
|
||||
"UPDATE certificates SET status = 'revoked'::cert_status, revoked_at = NOW() \
|
||||
WHERE host_id = $1 AND status = 'active'::cert_status",
|
||||
)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(db_error)?;
|
||||
|
||||
tracing::info!(%host_id, rows_revoked = revoked.rows_affected(), "Revoked all active certs for host");
|
||||
|
||||
// Issue a new certificate bundle using the host's FQDN and IP.
|
||||
let issued = state
|
||||
.ca
|
||||
.issue_client_cert(host_id, &fqdn, &ip_address, &state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to issue new cert during reissue");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateReissued,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "hostname": &fqdn, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number, "rows_revoked": revoked.rows_affected() }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(issued_cert_json(&issued)))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/certificates/:cert_id ─────────────────────────────────────
|
||||
|
||||
/// Revoke a certificate by ID. Sets status to 'revoked' in the database.
|
||||
async fn revoke_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(cert_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
state
|
||||
.ca
|
||||
.revoke_cert(cert_id, &state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = e.to_string();
|
||||
tracing::error!(error = %e, %cert_id, "Failed to revoke cert");
|
||||
if msg.contains("not found") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Certificate not found" } })),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
|
||||
)
|
||||
}
|
||||
})?;
|
||||
|
||||
tracing::info!(%cert_id, "Certificate revoked via API");
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateRevoked,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&cert_id.to_string()),
|
||||
json!({ "operation": "revoke" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "revoked": true })))
|
||||
}
|
||||
304
crates/pm-web/src/routes/discovery.rs
Executable file
304
crates/pm-web/src/routes/discovery.rs
Executable file
@ -0,0 +1,304 @@
|
||||
//! CIDR auto-discovery routes.
|
||||
//!
|
||||
//! POST /api/v1/discovery/cidr — start a CIDR scan
|
||||
//! GET /api/v1/discovery/:scan_id — get scan results
|
||||
//! POST /api/v1/discovery/:id/register — register a discovered host
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{DiscoveryCidrRequest, DiscoveryResult, RegisterDiscoveredRequest},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use std::{
|
||||
net::{IpAddr, TcpStream},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{sync::Semaphore, task};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
/// Maximum concurrent TCP probes during CIDR scan.
|
||||
const MAX_CONCURRENT_PROBES: usize = 128;
|
||||
/// TCP connect timeout per probe.
|
||||
const PROBE_TIMEOUT_SECS: u64 = 2;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/cidr", post(start_cidr_scan))
|
||||
.route("/{scan_id}", get(get_scan_results))
|
||||
.route("/{id}/register", post(register_discovered_host))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/discovery/cidr ───────────────────────────────────────────────
|
||||
|
||||
async fn start_cidr_scan(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<DiscoveryCidrRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let cidr: ipnet::IpNet = req.cidr.parse().map_err(|_| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "bad_request", "message": "Invalid CIDR range" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let agent_port = req.agent_port.unwrap_or(12443) as u16;
|
||||
let scan_id = Uuid::new_v4();
|
||||
|
||||
// Clear previous results for this type of scan and start async scan
|
||||
let pool = state.db.clone();
|
||||
let scan_id_clone = scan_id;
|
||||
let cidr_str = req.cidr.clone();
|
||||
|
||||
// Spawn non-blocking background scan
|
||||
task::spawn(async move {
|
||||
run_cidr_scan(pool, scan_id_clone, cidr, agent_port).await;
|
||||
});
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::DiscoveryScanStarted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("discovery"),
|
||||
Some(&scan_id.to_string()),
|
||||
json!({ "cidr": cidr_str }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(scan_id = %scan_id, cidr = %req.cidr, "CIDR scan started");
|
||||
Ok(Json(
|
||||
json!({ "scan_id": scan_id, "message": "Discovery scan started", "cidr": req.cidr }),
|
||||
))
|
||||
}
|
||||
|
||||
/// Background CIDR scanner.
|
||||
async fn run_cidr_scan(pool: sqlx::PgPool, scan_id: Uuid, cidr: ipnet::IpNet, port: u16) {
|
||||
let semaphore = std::sync::Arc::new(Semaphore::new(MAX_CONCURRENT_PROBES));
|
||||
let hosts: Vec<IpAddr> = cidr.hosts().collect();
|
||||
let total = hosts.len();
|
||||
|
||||
tracing::info!(scan_id = %scan_id, total = total, "CIDR scan probing {} hosts", total);
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for ip in hosts {
|
||||
let sem = semaphore.clone();
|
||||
let pool_clone = pool.clone();
|
||||
let h = task::spawn(async move {
|
||||
let _permit = sem.acquire().await.ok()?;
|
||||
probe_and_store(pool_clone, scan_id, ip, port).await
|
||||
});
|
||||
handles.push(h);
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
let _ = h.await;
|
||||
}
|
||||
|
||||
tracing::info!(scan_id = %scan_id, "CIDR scan complete");
|
||||
}
|
||||
|
||||
/// Probe a single IP:port and store the result if the port is open.
|
||||
async fn probe_and_store(pool: sqlx::PgPool, scan_id: Uuid, ip: IpAddr, port: u16) -> Option<()> {
|
||||
let addr = format!("{ip}:{port}");
|
||||
|
||||
// TCP connect probe (blocking, run in thread pool)
|
||||
// TCP connect probe (blocking, run in thread pool)
|
||||
let addr_clone = addr.clone();
|
||||
let open = task::spawn_blocking(move || {
|
||||
TcpStream::connect_timeout(
|
||||
&match addr_clone.parse() {
|
||||
Ok(a) => a,
|
||||
Err(_) => return false,
|
||||
},
|
||||
Duration::from_secs(PROBE_TIMEOUT_SECS),
|
||||
)
|
||||
.is_ok()
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !open {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Reverse DNS lookup (best-effort)
|
||||
let ip_clone = ip;
|
||||
let fqdn = task::spawn_blocking(move || {
|
||||
use std::net::ToSocketAddrs;
|
||||
let addr = format!("{ip_clone}:{port}");
|
||||
addr.to_socket_addrs()
|
||||
.ok()
|
||||
.and_then(|mut a| a.next())
|
||||
.and_then(|_| dns_lookup_for_ip(ip_clone))
|
||||
})
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let _ = sqlx::query(
|
||||
r#"INSERT INTO discovery_results (scan_id, ip_address, fqdn, agent_port)
|
||||
VALUES ($1, $2::inet, $3, $4)
|
||||
ON CONFLICT DO NOTHING"#,
|
||||
)
|
||||
.bind(scan_id)
|
||||
.bind(ip.to_string())
|
||||
.bind(fqdn)
|
||||
.bind(port as i32)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
tracing::debug!(ip = %ip, port = port, "Discovered agent");
|
||||
Some(())
|
||||
}
|
||||
|
||||
/// Simple reverse DNS lookup.
|
||||
fn dns_lookup_for_ip(ip: IpAddr) -> Option<String> {
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
let _addr = SocketAddr::new(ip, 0);
|
||||
// Standard library doesn't have reverse lookup; use getaddrinfo via format
|
||||
let host = format!("{ip}");
|
||||
// Best-effort: try to resolve numeric address to hostname
|
||||
(host + ":0")
|
||||
.to_socket_addrs()
|
||||
.ok()?
|
||||
.next()
|
||||
.map(|a| a.ip().to_string())
|
||||
.filter(|s| s != &ip.to_string())
|
||||
}
|
||||
|
||||
// ── GET /api/v1/discovery/:scan_id ────────────────────────────────────────────
|
||||
|
||||
async fn get_scan_results(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(scan_id): Path<Uuid>,
|
||||
) -> Result<Json<Vec<DiscoveryResult>>, (StatusCode, Json<Value>)> {
|
||||
sqlx::query_as::<_, DiscoveryResult>(
|
||||
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
|
||||
agent_version, os_name, agent_port, discovered_at, registered
|
||||
FROM discovery_results
|
||||
WHERE scan_id = $1
|
||||
ORDER BY ip_address"#,
|
||||
)
|
||||
.bind(scan_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── POST /api/v1/discovery/:id/register ──────────────────────────────────────
|
||||
|
||||
async fn register_discovered_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<RegisterDiscoveredRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch discovery result
|
||||
let result: Option<DiscoveryResult> = sqlx::query_as(
|
||||
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
|
||||
agent_version, os_name, agent_port, discovered_at, registered
|
||||
FROM discovery_results WHERE id = $1"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let result = result.ok_or_else(|| (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Discovery result not found" } }))
|
||||
))?;
|
||||
|
||||
let fqdn = result.fqdn.as_deref().unwrap_or(&result.ip_address);
|
||||
let display_name = req.display_name.as_deref().unwrap_or(fqdn);
|
||||
|
||||
let host_id: Uuid = sqlx::query_scalar(
|
||||
r#"INSERT INTO hosts (fqdn, ip_address, display_name, agent_port)
|
||||
VALUES ($1, $2::inet, $3, $4)
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(fqdn)
|
||||
.bind(&result.ip_address)
|
||||
.bind(display_name)
|
||||
.bind(result.agent_port)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Assign to groups
|
||||
if let Some(group_ids) = &req.group_ids {
|
||||
for gid in group_ids {
|
||||
let _ = sqlx::query("INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING")
|
||||
.bind(host_id).bind(gid).execute(&state.db).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as registered
|
||||
let _ = sqlx::query("UPDATE discovery_results SET registered = TRUE WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::HostRegistered,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "from_discovery": true, "ip": result.ip_address }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(
|
||||
json!({ "host_id": host_id, "message": "Host registered from discovery" }),
|
||||
))
|
||||
}
|
||||
319
crates/pm-web/src/routes/enrollment.rs
Normal file
319
crates/pm-web/src/routes/enrollment.rs
Normal file
@ -0,0 +1,319 @@
|
||||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use pm_auth::AuthUser;
|
||||
use pm_core::{
|
||||
db,
|
||||
models::{
|
||||
CreateEnrollmentRequest, EnrollmentRequest, EnrollmentStatusResponse, Host, PkiBundle,
|
||||
},
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct HostConflict {
|
||||
pub existing_host: Host,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Define public enrollment routes.
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/enroll", post(enroll_host))
|
||||
.route("/enroll/status/{token}", get(enroll_status))
|
||||
}
|
||||
|
||||
/// POST /api/v1/enroll
|
||||
/// Initiates host self-enrollment.
|
||||
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
|
||||
async fn enroll_host(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<CreateEnrollmentRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Generate secure random polling token
|
||||
let polling_token: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(64)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
// For database storage, we'll hash the token (spec says hashed)
|
||||
// Using a simple SHA256 or similar for the hash storage
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(polling_token.as_bytes());
|
||||
let token_hash = hex::encode(hasher.finalize());
|
||||
|
||||
// 3. Store in DB
|
||||
db::create_enrollment_request(&state.db, payload, token_hash)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to create enrollment request");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 4. Return the raw token to the client
|
||||
Ok((
|
||||
StatusCode::ACCEPTED,
|
||||
Json(serde_json::json!({ "polling_token": polling_token })),
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
|
||||
/// GET /api/v1/enroll/status/{token}
|
||||
/// Returns status of enrollment (pending/approved/denied/not_found).
|
||||
async fn enroll_status(
|
||||
State(state): State<AppState>,
|
||||
Path(token): Path<String>,
|
||||
) -> Result<Json<EnrollmentStatusResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Hash the provided token to match DB
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let token_hash = hex::encode(hasher.finalize());
|
||||
|
||||
// 1. Check enrollment_requests table
|
||||
let requests = db::list_enrollment_requests(&state.db).await.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(req) = requests.into_iter().find(|r| r.polling_token == token_hash) {
|
||||
if req.expires_at < Utc::now() {
|
||||
return Ok(Json(EnrollmentStatusResponse::NotFound));
|
||||
}
|
||||
return Ok(Json(EnrollmentStatusResponse::Pending));
|
||||
}
|
||||
|
||||
// 2. If not in pending, check if it was recently approved.
|
||||
if let Some(pki) = state.approved_enrollments.get(&token_hash) {
|
||||
return Ok(Json(EnrollmentStatusResponse::Approved {
|
||||
ca_crt: pki.ca_crt.clone(),
|
||||
server_crt: pki.server_crt.clone(),
|
||||
server_key: pki.server_key.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(Json(EnrollmentStatusResponse::NotFound))
|
||||
}
|
||||
|
||||
/// Define admin enrollment routes.
|
||||
pub fn admin_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/enrollments", get(list_admin_enrollments))
|
||||
.route("/enrollments/{id}/approve", post(approve_enrollment))
|
||||
.route("/enrollments/{id}/deny", delete(deny_enrollment))
|
||||
}
|
||||
|
||||
/// GET /api/v1/admin/enrollments
|
||||
/// Lists all pending enrollment requests.
|
||||
async fn list_admin_enrollments(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Vec<EnrollmentRequest>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
db::list_enrollment_requests(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list enrollment requests");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// POST /api/v1/admin/enrollments/{id}/approve
|
||||
/// Approves a pending enrollment request, generates PKI, and moves to hosts table.
|
||||
async fn approve_enrollment(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<uuid::Uuid>,
|
||||
auth: AuthUser,
|
||||
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch the enrollment request
|
||||
let mut requests = db::list_enrollment_requests(&state.db).await.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list enrollment requests for approval");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let enrollment_request = match requests.iter().position(|r| r.id == id) {
|
||||
Some(idx) => requests.remove(idx),
|
||||
None => return Ok(StatusCode::NOT_FOUND),
|
||||
};
|
||||
|
||||
// Check for FQDN/IP collision in hosts table
|
||||
if let Some(existing_host) = sqlx::query_as::<_, Host>(
|
||||
"SELECT id, fqdn, ip_address::text, display_name, os_family, os_name, arch, agent_version, health_status, last_health_at, last_patch_at, agent_port, notes, registered_at, updated_at FROM hosts WHERE fqdn = $1 OR ip_address = $2::inet"
|
||||
)
|
||||
.bind(&enrollment_request.fqdn)
|
||||
.bind(enrollment_request.ip_address.to_string())
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to check for host collision");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": "Database error" })))
|
||||
})? {
|
||||
return Err((
|
||||
StatusCode::CONFLICT,
|
||||
Json(serde_json::json!({ "error": "Host collision detected", "conflict": HostConflict { existing_host, message: "FQDN or IP already exists".to_string() } }))
|
||||
));
|
||||
}
|
||||
|
||||
// Move to hosts table FIRST (certificates table has FK reference to hosts)
|
||||
let os_family = enrollment_request
|
||||
.os_details
|
||||
.get("os")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let os_name = enrollment_request
|
||||
.os_details
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
// Build os_name from os + os_version if "name" is absent
|
||||
let os = enrollment_request
|
||||
.os_details
|
||||
.get("os")
|
||||
.and_then(|v| v.as_str())?;
|
||||
let ver = enrollment_request
|
||||
.os_details
|
||||
.get("os_version")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
Some(format!("{} {}", os, ver).trim().to_string())
|
||||
});
|
||||
let arch = enrollment_request
|
||||
.os_details
|
||||
.get("architecture")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let display_name = enrollment_request
|
||||
.hostname
|
||||
.clone()
|
||||
.unwrap_or_else(|| enrollment_request.fqdn.clone());
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO hosts (id, fqdn, ip_address, os_family, os_name, arch, display_name, registered_at, updated_at)
|
||||
VALUES ($1, $2, $3::inet, $4, $5, $6, $7, NOW(), NOW())
|
||||
"#,
|
||||
)
|
||||
.bind(enrollment_request.id)
|
||||
.bind(&enrollment_request.fqdn)
|
||||
.bind(enrollment_request.ip_address.to_string())
|
||||
.bind(&os_family)
|
||||
.bind(&os_name)
|
||||
.bind(&arch)
|
||||
.bind(&display_name)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to insert host after approval");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Generate PKI bundle using CA (after host row exists)
|
||||
let issued = state
|
||||
.ca
|
||||
.issue_client_cert(
|
||||
enrollment_request.id,
|
||||
&enrollment_request.fqdn,
|
||||
&enrollment_request.ip_address.to_string(),
|
||||
&state.db,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to issue client certificate");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Certificate generation failed" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Delete from enrollment_requests table
|
||||
db::delete_enrollment_request(&state.db, id)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to delete enrollment request after approval");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Store PKI bundle in cache for client retrieval
|
||||
let pki = PkiBundle {
|
||||
ca_crt: issued.ca_root_pem,
|
||||
server_crt: issued.server_cert_pem,
|
||||
server_key: issued.server_key_pem,
|
||||
};
|
||||
state
|
||||
.approved_enrollments
|
||||
.insert(enrollment_request.polling_token.clone(), pki);
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/admin/enrollments/{id}/deny
|
||||
/// Denies and purges a pending enrollment request.
|
||||
async fn deny_enrollment(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<uuid::Uuid>,
|
||||
auth: AuthUser,
|
||||
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
db::delete_enrollment_request(&state.db, id)
|
||||
.await
|
||||
.map(|_| StatusCode::NO_CONTENT)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to deny enrollment request");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})
|
||||
}
|
||||
312
crates/pm-web/src/routes/groups.rs
Executable file
312
crates/pm-web/src/routes/groups.rs
Executable file
@ -0,0 +1,312 @@
|
||||
//! Group management routes.
|
||||
//!
|
||||
//! GET /api/v1/groups — list all groups
|
||||
//! POST /api/v1/groups — create group (admin)
|
||||
//! GET /api/v1/groups/:id — get group detail + members
|
||||
//! PUT /api/v1/groups/:id — update group (admin)
|
||||
//! DELETE /api/v1/groups/:id — delete group (admin)
|
||||
//! POST /api/v1/groups/:id/users/:user_id — add user to group (admin)
|
||||
//! DELETE /api/v1/groups/:id/users/:user_id — remove user from group (admin)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateGroupRequest, Group, UpdateGroupRequest},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_groups).post(create_group))
|
||||
.route(
|
||||
"/{id}",
|
||||
get(get_group).put(update_group).delete(delete_group),
|
||||
)
|
||||
.route(
|
||||
"/{id}/users/{user_id}",
|
||||
post(add_user_to_group).delete(remove_user_from_group),
|
||||
)
|
||||
}
|
||||
|
||||
async fn list_groups(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
|
||||
sqlx::query_as::<_, Group>(
|
||||
"SELECT id, name, description, created_at, updated_at FROM groups ORDER BY name",
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list groups");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateGroupRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let id: Uuid =
|
||||
sqlx::query_scalar("INSERT INTO groups (name, description) VALUES ($1, $2) RETURNING id")
|
||||
.bind(&req.name)
|
||||
.bind(req.description.as_deref().unwrap_or(""))
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"Group name already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("group"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "name": req.name }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "id": id, "message": "Group created" })))
|
||||
}
|
||||
|
||||
async fn get_group(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let group: Option<Group> = sqlx::query_as(
|
||||
"SELECT id, name, description, created_at, updated_at FROM groups WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let group = group.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Fetch member counts
|
||||
let host_count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM host_groups WHERE group_id = $1")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let user_count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM user_groups WHERE group_id = $1")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok(Json(
|
||||
json!({ "group": group, "host_count": host_count, "user_count": user_count }),
|
||||
))
|
||||
}
|
||||
|
||||
async fn update_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<UpdateGroupRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query(
|
||||
"UPDATE groups SET name = COALESCE($1, name), description = COALESCE($2, description), updated_at = NOW() WHERE id = $3"
|
||||
)
|
||||
.bind(req.name.as_deref())
|
||||
.bind(req.description.as_deref())
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Json(json!({ "message": "Group updated" })))
|
||||
}
|
||||
|
||||
async fn delete_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query("DELETE FROM groups WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("group"),
|
||||
Some(&id.to_string()),
|
||||
json!({}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Group deleted" })))
|
||||
}
|
||||
|
||||
async fn add_user_to_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((id, user_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO user_groups (user_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user_group"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "user_id": user_id, "action": "added" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User added to group" })))
|
||||
}
|
||||
|
||||
async fn remove_user_from_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((id, user_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM user_groups WHERE user_id = $1 AND group_id = $2")
|
||||
.bind(user_id)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user_group"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "user_id": user_id, "action": "removed" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User removed from group" })))
|
||||
}
|
||||
1159
crates/pm-web/src/routes/health_checks.rs
Executable file
1159
crates/pm-web/src/routes/health_checks.rs
Executable file
File diff suppressed because it is too large
Load Diff
678
crates/pm-web/src/routes/hosts.rs
Executable file
678
crates/pm-web/src/routes/hosts.rs
Executable file
@ -0,0 +1,678 @@
|
||||
//! Host management routes.
|
||||
//!
|
||||
//! GET /api/v1/hosts — list hosts (RBAC scoped)
|
||||
//! POST /api/v1/hosts — register new host (admin only)
|
||||
//! GET /api/v1/hosts/{id} — get host detail
|
||||
//! DELETE /api/v1/hosts/{id} — remove host (admin only)
|
||||
//! PUT /api/v1/hosts/{id} — update host (write access)
|
||||
//! GET /api/v1/hosts/{id}/groups — list groups for host
|
||||
//! POST /api/v1/hosts/{id}/groups — assign host to group
|
||||
//! DELETE /api/v1/hosts/{id}/groups/{group_id} — remove host from group
|
||||
//! POST /api/v1/hosts/{id}/refresh — queue on-demand refresh (write access)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateHostRequest, Group, HostSummary, UpdateHostRequest},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_hosts).post(register_host))
|
||||
.route("/{id}", get(get_host).put(update_host).delete(remove_host))
|
||||
.route(
|
||||
"/{id}/groups",
|
||||
get(list_host_groups).post(add_host_to_group),
|
||||
)
|
||||
.route("/{id}/groups/{group_id}", delete(remove_host_from_group))
|
||||
.route("/{id}/refresh", post(refresh_host))
|
||||
}
|
||||
|
||||
// ── Query params ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct HostListQuery {
|
||||
pub group_id: Option<Uuid>,
|
||||
pub health_status: Option<String>,
|
||||
pub os_family: Option<String>,
|
||||
pub search: Option<String>,
|
||||
pub limit: Option<i64>,
|
||||
pub offset: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Response types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct HostListResponse {
|
||||
hosts: Vec<HostSummary>,
|
||||
total: i64,
|
||||
limit: i64,
|
||||
offset: i64,
|
||||
}
|
||||
|
||||
// ── Helper: check if operator can access a host ───────────────────────────────
|
||||
|
||||
async fn operator_can_access_host(
|
||||
pool: &sqlx::PgPool,
|
||||
user_id: Uuid,
|
||||
host_id: Uuid,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
// Admins can access all; operators can access hosts in their groups
|
||||
// OR ungrouped hosts (no group memberships)
|
||||
let in_group: bool = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM host_groups hg
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE hg.host_id = $1 AND ug.user_id = $2
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if in_group {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Ungrouped hosts are accessible to any operator
|
||||
let ungrouped: bool =
|
||||
sqlx::query_scalar("SELECT NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(ungrouped)
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn list_hosts(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Query(q): Query<HostListQuery>,
|
||||
) -> Result<Json<HostListResponse>, (StatusCode, Json<Value>)> {
|
||||
let limit = q.limit.unwrap_or(50).min(200);
|
||||
let offset = q.offset.unwrap_or(0);
|
||||
|
||||
// For operators: only show hosts in their groups (or ungrouped)
|
||||
let hosts: Vec<HostSummary> = if auth.role.is_admin() {
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT h.id, h.fqdn, host(h.ip_address)::text AS ip_address, h.display_name,
|
||||
h.os_family, h.os_name, h.health_status, h.agent_version,
|
||||
COALESCE(hpd.patch_count, 0) AS patches_missing,
|
||||
CASE
|
||||
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
|
||||
THEN NULL
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM host_health_checks hc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT healthy FROM host_health_check_results r
|
||||
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
|
||||
) lr ON TRUE
|
||||
WHERE hc.host_id = h.id AND hc.enabled = TRUE
|
||||
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
|
||||
)
|
||||
THEN 'some_unhealthy'
|
||||
ELSE 'all_healthy'
|
||||
END AS health_check_status,
|
||||
h.registered_at
|
||||
FROM hosts h
|
||||
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
|
||||
ORDER BY h.fqdn
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT h.id, h.fqdn, host(h.ip_address)::text AS ip_address,
|
||||
h.display_name, h.os_family, h.os_name,
|
||||
h.health_status, h.agent_version,
|
||||
COALESCE(hpd.patch_count, 0) AS patches_missing,
|
||||
CASE
|
||||
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
|
||||
THEN NULL
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM host_health_checks hc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT healthy FROM host_health_check_results r
|
||||
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
|
||||
) lr ON TRUE
|
||||
WHERE hc.host_id = h.id AND hc.enabled = TRUE
|
||||
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
|
||||
)
|
||||
THEN 'some_unhealthy'
|
||||
ELSE 'all_healthy'
|
||||
END AS health_check_status,
|
||||
h.registered_at
|
||||
FROM hosts h
|
||||
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
|
||||
WHERE
|
||||
-- Hosts in operator's groups
|
||||
EXISTS (
|
||||
SELECT 1 FROM host_groups hg
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE hg.host_id = h.id AND ug.user_id = $3
|
||||
)
|
||||
-- OR ungrouped hosts
|
||||
OR NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = h.id)
|
||||
ORDER BY h.fqdn
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.bind(auth.user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
}
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM hosts")
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok(Json(HostListResponse {
|
||||
hosts,
|
||||
total,
|
||||
limit,
|
||||
offset,
|
||||
}))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts ────────────────────────────────────────────────────────
|
||||
|
||||
async fn register_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateHostRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Admin only
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Resolve FQDN to IP address
|
||||
let ip_address = resolve_fqdn(&req.fqdn).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "fqdn_resolution_failed", "message": e } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let display_name = req.display_name.clone().unwrap_or_else(|| req.fqdn.clone());
|
||||
let agent_port = req.agent_port.unwrap_or(12443);
|
||||
let notes = req.notes.clone().unwrap_or_default();
|
||||
|
||||
// Insert host
|
||||
let host_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO hosts (fqdn, ip_address, display_name, agent_port, notes)
|
||||
VALUES ($1, $2::inet, $3, $4, $5)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(&req.fqdn)
|
||||
.bind(&ip_address)
|
||||
.bind(&display_name)
|
||||
.bind(agent_port)
|
||||
.bind(¬es)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"Host with this FQDN and IP already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
tracing::error!(error = %e, "Failed to register host");
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Assign to groups if specified
|
||||
if let Some(group_ids) = &req.group_ids {
|
||||
for gid in group_ids {
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(gid)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Audit log
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::HostRegistered,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "fqdn": req.fqdn, "ip": ip_address }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(host_id = %host_id, fqdn = %req.fqdn, "Host registered");
|
||||
Ok(Json(json!({ "id": host_id, "message": "Host registered" })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn get_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !can_access {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let host: Option<Value> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT row_to_json(h) FROM (
|
||||
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
|
||||
os_family, os_name, arch, agent_version, health_status,
|
||||
last_health_at, last_patch_at, agent_port, notes,
|
||||
registered_at, updated_at
|
||||
FROM hosts WHERE id = $1
|
||||
) h
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to get host");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
host.map(Json).ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:id ──────────────────────────────────────────────────
|
||||
|
||||
async fn remove_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch FQDN for audit before deletion
|
||||
let fqdn: Option<String> = sqlx::query_scalar("SELECT fqdn FROM hosts WHERE id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let result = sqlx::query("DELETE FROM hosts WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to remove host");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::HostRemoved,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "fqdn": fqdn }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(host_id = %id, "Host removed");
|
||||
Ok(Json(json!({ "message": "Host removed" })))
|
||||
}
|
||||
|
||||
// ── PUT /api/v1/hosts/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn update_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<UpdateHostRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Update only fields that were provided; COALESCE preserves existing values.
|
||||
let host = sqlx::query_scalar(
|
||||
r#"
|
||||
WITH updated AS (
|
||||
UPDATE hosts SET
|
||||
fqdn = COALESCE($1, fqdn),
|
||||
ip_address = COALESCE($2::inet, ip_address),
|
||||
display_name = COALESCE($3, display_name),
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
RETURNING id
|
||||
)
|
||||
SELECT row_to_json(h) FROM (
|
||||
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
|
||||
os_family, os_name, arch, agent_version, health_status,
|
||||
last_health_at, last_patch_at, agent_port, notes,
|
||||
registered_at, updated_at
|
||||
FROM hosts WHERE id = (SELECT id FROM updated)
|
||||
) h
|
||||
"#,
|
||||
)
|
||||
.bind(&req.fqdn)
|
||||
.bind(&req.ip_address)
|
||||
.bind(&req.display_name)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, host_id = %id, "Failed to update host");
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"A host with this FQDN and IP already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
host.map(Json).ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:id/groups ──────────────────────────────────────────────
|
||||
|
||||
async fn list_host_groups(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !can_access {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let groups: Vec<Group> = sqlx::query_as(
|
||||
r#"SELECT g.id, g.name, g.description, g.created_at, g.updated_at
|
||||
FROM groups g
|
||||
JOIN host_groups hg ON hg.group_id = g.id
|
||||
WHERE hg.host_id = $1
|
||||
ORDER BY g.name"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list host groups");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(groups))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:id/groups ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AddToGroupRequest {
|
||||
group_id: Uuid,
|
||||
}
|
||||
|
||||
async fn add_host_to_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<AddToGroupRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(id)
|
||||
.bind(req.group_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to add host to group");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "group_id": req.group_id, "action": "added" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Host added to group" })))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:id/groups/:group_id ─────────────────────────────────
|
||||
|
||||
async fn remove_host_from_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((id, group_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM host_groups WHERE host_id = $1 AND group_id = $2")
|
||||
.bind(id)
|
||||
.bind(group_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to remove host from group");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "group_id": group_id, "action": "removed" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Host removed from group" })))
|
||||
}
|
||||
|
||||
// ── FQDN resolution ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Resolve an FQDN (or IP) to its primary IP address.
|
||||
/// If the input is already a valid IP, returns it as-is.
|
||||
async fn resolve_fqdn(fqdn: &str) -> Result<String, String> {
|
||||
use std::net::ToSocketAddrs;
|
||||
// Try direct IP parse first
|
||||
if fqdn.parse::<std::net::IpAddr>().is_ok() {
|
||||
return Ok(fqdn.to_string());
|
||||
}
|
||||
// DNS resolution
|
||||
let addr = format!("{fqdn}:0");
|
||||
match tokio::task::spawn_blocking(move || addr.to_socket_addrs()).await {
|
||||
Ok(Ok(mut addrs)) => addrs
|
||||
.next()
|
||||
.map(|a| a.ip().to_string())
|
||||
.ok_or_else(|| format!("No addresses found for {fqdn}")),
|
||||
_ => Err(format!("Failed to resolve FQDN: {fqdn}")),
|
||||
}
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:id/refresh ───────────────────────────────────────────
|
||||
|
||||
/// Queue an on-demand health + patch refresh for a single host.
|
||||
///
|
||||
/// Sends a PostgreSQL NOTIFY on the `refresh_requested` channel; the
|
||||
/// pm-worker refresh listener picks this up and polls the host immediately.
|
||||
/// Requires Operator or Admin role (any authenticated user).
|
||||
async fn refresh_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<(StatusCode, Json<Value>), (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
// Verify the host exists.
|
||||
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "refresh_host: db error checking host existence");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !exists {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// NOTIFY the worker's refresh listener.
|
||||
sqlx::query("SELECT pg_notify('refresh_requested', $1)")
|
||||
.bind(id.to_string())
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "refresh_host: pg_notify failed");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to queue refresh" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(%id, "On-demand refresh queued");
|
||||
|
||||
Ok((
|
||||
StatusCode::ACCEPTED,
|
||||
Json(json!({ "message": "Refresh queued" })),
|
||||
))
|
||||
}
|
||||
677
crates/pm-web/src/routes/jobs.rs
Executable file
677
crates/pm-web/src/routes/jobs.rs
Executable file
@ -0,0 +1,677 @@
|
||||
//! Patch job management routes.
|
||||
//!
|
||||
//! POST /api/v1/jobs — create a new patch job (operator+)
|
||||
//! GET /api/v1/jobs — list jobs with pagination (RBAC scoped)
|
||||
//! GET /api/v1/jobs/{id} — get job detail + per-host status
|
||||
//! POST /api/v1/jobs/{id}/cancel — cancel a queued/pending job (admin or creator)
|
||||
//! POST /api/v1/jobs/{id}/rollback — create a rollback job (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateJobRequest, PatchJobSummary},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_jobs).post(create_job))
|
||||
.route("/{id}", get(get_job))
|
||||
.route("/{id}/cancel", post(cancel_job))
|
||||
.route("/{id}/rollback", post(rollback_job))
|
||||
}
|
||||
|
||||
// ── Query params ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct JobListQuery {
|
||||
pub limit: Option<i64>,
|
||||
pub offset: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Response types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct JobListResponse {
|
||||
jobs: Vec<PatchJobSummary>,
|
||||
total: i64,
|
||||
limit: i64,
|
||||
offset: i64,
|
||||
}
|
||||
|
||||
/// Per-host row included in `GET /api/v1/jobs/{id}` response.
|
||||
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
|
||||
struct JobHostRow {
|
||||
pub id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub display_name: String,
|
||||
pub status: String,
|
||||
pub agent_job_id: Option<String>,
|
||||
pub retry_count: i32,
|
||||
pub output: String,
|
||||
pub error_message: Option<String>,
|
||||
pub last_error: Option<String>,
|
||||
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
// ── Error helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── RBAC helper ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Returns `true` when the operator's groups contain at least one host that
|
||||
/// belongs to the given job. Admins always pass this check at the call site.
|
||||
async fn operator_can_access_job(
|
||||
pool: &sqlx::PgPool,
|
||||
user_id: Uuid,
|
||||
job_id: Uuid,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN host_groups hg ON hg.host_id = pjh.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh.job_id = $1
|
||||
AND ug.user_id = $2
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn create_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateJobRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
if req.host_ids.is_empty() {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"host_ids must not be empty",
|
||||
));
|
||||
}
|
||||
|
||||
// Encode package list as JSONB.
|
||||
let patch_selection = serde_json::to_value(&req.packages).unwrap_or(json!([]));
|
||||
let notes = req.notes.clone().unwrap_or_default();
|
||||
|
||||
// Insert the parent job row; the DB NOTIFY trigger fires automatically
|
||||
// when immediate = TRUE (see migration 003_jobs_scheduling.sql).
|
||||
let job_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, created_by_user_id, maintenance_window_id,
|
||||
immediate, patch_selection, notes)
|
||||
VALUES
|
||||
('patch_apply'::job_kind, 'queued'::job_status, $1, $2, $3, $4, $5)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.bind(req.maintenance_window_id)
|
||||
.bind(req.immediate)
|
||||
.bind(&patch_selection)
|
||||
.bind(¬es)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "create_job: insert patch_jobs failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Insert one patch_job_hosts row per requested host.
|
||||
for host_id in &req.host_ids {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
VALUES ($1, $2, 'queued'::job_status)
|
||||
ON CONFLICT (job_id, host_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e, %job_id, %host_id,
|
||||
"create_job: insert patch_job_hosts failed"
|
||||
);
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&job_id.to_string()),
|
||||
json!({
|
||||
"kind": "patch_apply",
|
||||
"immediate": req.immediate,
|
||||
"host_count": req.host_ids.len(),
|
||||
"packages": req.packages,
|
||||
"notes": notes,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
job_id = %job_id,
|
||||
host_count = req.host_ids.len(),
|
||||
immediate = req.immediate,
|
||||
user = %auth.username,
|
||||
"Patch job created"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "id": job_id, "message": "Job created" })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/jobs ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn list_jobs(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Query(q): Query<JobListQuery>,
|
||||
) -> Result<Json<JobListResponse>, (StatusCode, Json<Value>)> {
|
||||
let limit = q.limit.unwrap_or(50).min(200);
|
||||
let offset = q.offset.unwrap_or(0);
|
||||
|
||||
let jobs: Vec<PatchJobSummary> = if auth.role.is_admin() {
|
||||
// Admins see every job.
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.kind,
|
||||
pj.status,
|
||||
pj.immediate,
|
||||
pj.notes,
|
||||
pj.created_at,
|
||||
pj.started_at,
|
||||
pj.completed_at,
|
||||
COUNT(pjh.id) AS host_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
|
||||
FROM patch_jobs pj
|
||||
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
|
||||
GROUP BY pj.id
|
||||
ORDER BY pj.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
} else {
|
||||
// Operators: only jobs where at least one host is in their groups.
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.kind,
|
||||
pj.status,
|
||||
pj.immediate,
|
||||
pj.notes,
|
||||
pj.created_at,
|
||||
pj.started_at,
|
||||
pj.completed_at,
|
||||
COUNT(pjh.id) AS host_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
|
||||
FROM patch_jobs pj
|
||||
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh2
|
||||
JOIN host_groups hg ON hg.host_id = pjh2.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh2.job_id = pj.id
|
||||
AND ug.user_id = $3
|
||||
)
|
||||
GROUP BY pj.id
|
||||
ORDER BY pj.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.bind(auth.user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
}
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "list_jobs: query failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
// Total count for pagination metadata.
|
||||
let total: i64 = if auth.role.is_admin() {
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM patch_jobs")
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
} else {
|
||||
sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(DISTINCT pj.id)
|
||||
FROM patch_jobs pj
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN host_groups hg ON hg.host_id = pjh.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh.job_id = pj.id
|
||||
AND ug.user_id = $1
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
};
|
||||
|
||||
Ok(Json(JobListResponse {
|
||||
jobs,
|
||||
total,
|
||||
limit,
|
||||
offset,
|
||||
}))
|
||||
}
|
||||
// ── GET /api/v1/jobs/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn get_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// RBAC: operators may only view jobs touching their group's hosts.
|
||||
if !auth.role.is_admin() {
|
||||
let allowed = operator_can_access_job(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !allowed {
|
||||
return Err(err(StatusCode::FORBIDDEN, "forbidden", "Access denied"));
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the job header row as JSON.
|
||||
let job: Option<Value> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT row_to_json(j) FROM (
|
||||
SELECT id, kind, status, created_by_user_id, parent_job_id,
|
||||
maintenance_window_id, immediate, patch_selection, notes,
|
||||
created_at, started_at, completed_at
|
||||
FROM patch_jobs
|
||||
WHERE id = $1
|
||||
) j
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "get_job: failed to fetch job");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
let job = job.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
|
||||
|
||||
// Fetch per-host status rows joined to the host display name.
|
||||
let hosts: Vec<JobHostRow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pjh.id,
|
||||
pjh.host_id,
|
||||
COALESCE(h.display_name, h.fqdn) AS display_name,
|
||||
pjh.status::text AS status,
|
||||
pjh.agent_job_id,
|
||||
pjh.retry_count,
|
||||
pjh.output,
|
||||
pjh.error_message,
|
||||
pjh.last_error,
|
||||
pjh.started_at,
|
||||
pjh.completed_at
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.job_id = $1
|
||||
ORDER BY h.display_name
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "get_job: failed to fetch host rows");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "job": job, "hosts": hosts })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs/:id/cancel ─────────────────────────────────────────────
|
||||
|
||||
async fn cancel_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Fetch the job to verify it exists and check ownership.
|
||||
let row: Option<(String, Option<Uuid>)> =
|
||||
sqlx::query_as("SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: db fetch failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
let (status_str, creator_id) =
|
||||
row.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
|
||||
|
||||
// Only admin or the job creator may cancel.
|
||||
if !auth.role.can_write() {
|
||||
let is_creator = creator_id == Some(auth.user_id);
|
||||
if !is_creator {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Only queued or pending jobs can be cancelled.
|
||||
if status_str != "queued" && status_str != "pending" {
|
||||
return Err(err(
|
||||
StatusCode::CONFLICT,
|
||||
"invalid_state",
|
||||
format!(
|
||||
"Cannot cancel a job in '{}' state; only queued or pending jobs may be cancelled",
|
||||
status_str
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Cancel the parent job.
|
||||
sqlx::query("UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Cancel all queued/pending host rows for this job.
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'cancelled'::job_status
|
||||
WHERE job_id = $1
|
||||
AND status IN ('queued'::job_status, 'pending'::job_status)
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: update patch_job_hosts failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Fire job-level pg_notify so the frontend can update the job row.
|
||||
let notify_payload = json!({
|
||||
"event_type": "job",
|
||||
"job_id": id.to_string(),
|
||||
"host_id": "",
|
||||
"status": "cancelled",
|
||||
"succeeded_count": 0,
|
||||
"failed_count": 0,
|
||||
"host_count": 0,
|
||||
});
|
||||
if let Ok(payload_str) = serde_json::to_string(¬ify_payload) {
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_str)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, %id, "cancel_job: job-level pg_notify failed");
|
||||
} else {
|
||||
tracing::info!(%id, "cancel_job: job-level pg_notify sent");
|
||||
}
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobCancelled,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "previous_status": status_str }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(job_id = %id, user = %auth.username, "Patch job cancelled");
|
||||
Ok(Json(json!({ "message": "Job cancelled" })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs/:id/rollback ────────────────────────────────────────────
|
||||
|
||||
async fn rollback_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Admin-only operation.
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
|
||||
// Verify the original job exists.
|
||||
let original_exists: bool =
|
||||
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM patch_jobs WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "rollback_job: existence check failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !original_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Job not found"));
|
||||
}
|
||||
|
||||
// Gather the host IDs from the original job.
|
||||
let host_ids: Vec<Uuid> =
|
||||
sqlx::query_scalar("SELECT host_id FROM patch_job_hosts WHERE job_id = $1")
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "rollback_job: host fetch failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if host_ids.is_empty() {
|
||||
return Err(err(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"no_hosts",
|
||||
"Original job has no host entries to roll back",
|
||||
));
|
||||
}
|
||||
|
||||
// Create the rollback job row (immediate = true so the worker picks it up
|
||||
// right away and the NOTIFY trigger fires).
|
||||
let rollback_job_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, created_by_user_id, parent_job_id, immediate,
|
||||
patch_selection, notes)
|
||||
VALUES
|
||||
('rollback'::job_kind, 'queued'::job_status, $1, $2, TRUE,
|
||||
'[]'::jsonb, $3)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.bind(id)
|
||||
.bind(format!("Rollback of job {}", id))
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, parent_job_id = %id, "rollback_job: insert failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Replicate host list into the rollback job.
|
||||
for host_id in &host_ids {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
VALUES ($1, $2, 'queued'::job_status)
|
||||
ON CONFLICT (job_id, host_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(rollback_job_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e, %rollback_job_id, %host_id,
|
||||
"rollback_job: insert patch_job_hosts failed"
|
||||
);
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobRollback,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&rollback_job_id.to_string()),
|
||||
json!({
|
||||
"original_job_id": id,
|
||||
"rollback_job_id": rollback_job_id,
|
||||
"host_count": host_ids.len(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
rollback_job_id = %rollback_job_id,
|
||||
original_job_id = %id,
|
||||
user = %auth.username,
|
||||
"Rollback job created"
|
||||
);
|
||||
|
||||
Ok(Json(json!({
|
||||
"id": rollback_job_id,
|
||||
"parent_job_id": id,
|
||||
"message": "Rollback job created"
|
||||
})))
|
||||
}
|
||||
452
crates/pm-web/src/routes/maintenance_windows.rs
Normal file
452
crates/pm-web/src/routes/maintenance_windows.rs
Normal file
@ -0,0 +1,452 @@
|
||||
//! Maintenance window management routes.
|
||||
//!
|
||||
//! GET /api/v1/hosts/{id}/maintenance-windows — list windows for host
|
||||
//! GET /api/v1/maintenance-windows — list ALL windows (bulk)
|
||||
//! POST /api/v1/hosts/{id}/maintenance-windows — create window for host
|
||||
//! PUT /api/v1/hosts/{id}/maintenance-windows/{win_id} — update window
|
||||
//! DELETE /api/v1/hosts/{id}/maintenance-windows/{win_id} — delete window
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, put},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Mount as a nested router under `/hosts/{host_id}/maintenance-windows`.
|
||||
/// Axum will merge the `{host_id}` path segment from the parent nest.
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_windows).post(create_window))
|
||||
.route("/{win_id}", put(update_window).delete(delete_window))
|
||||
}
|
||||
|
||||
/// Top-level router for `/api/v1/maintenance-windows` — bulk list-all endpoint.
|
||||
pub fn all_windows_router() -> Router<AppState> {
|
||||
Router::new().route("/", get(list_all_windows))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/maintenance-windows ──────────────────────────────────────────
|
||||
|
||||
/// Bulk endpoint: return every maintenance window across all hosts.
|
||||
/// Eliminates N+1 queries from the frontend (one request instead of one per host).
|
||||
async fn list_all_windows(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
ORDER BY host_id, created_at ASC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "list_all_windows: query failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "windows": windows })))
|
||||
}
|
||||
|
||||
// ── Error helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:host_id/maintenance-windows ────────────────────────────
|
||||
|
||||
async fn list_windows(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Verify host exists.
|
||||
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "list_windows: host existence check failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !host_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
|
||||
}
|
||||
|
||||
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
WHERE host_id = $1
|
||||
ORDER BY created_at ASC
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "list_windows: query failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "windows": windows })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/maintenance-windows ───────────────────────────
|
||||
|
||||
async fn create_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
Json(req): Json<CreateMaintenanceWindowRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
// Validate: weekly requires recurrence_day 0-6
|
||||
if req.recurrence == pm_core::models::WindowRecurrence::Weekly {
|
||||
match req.recurrence_day {
|
||||
Some(d) if (0..=6).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Weekly recurrence requires recurrence_day 0-6 (0=Sunday)",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate: monthly requires recurrence_day 1-31
|
||||
if req.recurrence == pm_core::models::WindowRecurrence::Monthly {
|
||||
match req.recurrence_day {
|
||||
Some(d) if (1..=31).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Monthly recurrence requires recurrence_day 1-31",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Verify host exists.
|
||||
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "create_window: host existence check failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !host_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
|
||||
}
|
||||
|
||||
let duration = req.duration_minutes.unwrap_or(60);
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
let auto_apply = req.auto_apply.unwrap_or(true);
|
||||
|
||||
let window: MaintenanceWindow = sqlx::query_as(
|
||||
r#"
|
||||
INSERT INTO maintenance_windows
|
||||
(host_id, label, recurrence, start_at, duration_minutes, recurrence_day, enabled, auto_apply)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(&req.label)
|
||||
.bind(&req.recurrence)
|
||||
.bind(req.start_at)
|
||||
.bind(duration)
|
||||
.bind(req.recurrence_day)
|
||||
.bind(enabled)
|
||||
.bind(auto_apply)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "create_window: insert failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&window.id.to_string()),
|
||||
json!({
|
||||
"host_id": host_id,
|
||||
"label": window.label,
|
||||
"recurrence": window.recurrence.to_string(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %window.id,
|
||||
%host_id,
|
||||
recurrence = %window.recurrence,
|
||||
user = %auth.username,
|
||||
"Maintenance window created"
|
||||
);
|
||||
|
||||
Ok(Json(json!(window)))
|
||||
}
|
||||
|
||||
// ── PUT /api/v1/hosts/:host_id/maintenance-windows/:win_id ───────────────────
|
||||
|
||||
async fn update_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
|
||||
Json(req): Json<UpdateMaintenanceWindowRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
// Fetch existing record (verify ownership and existence).
|
||||
let existing: Option<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
WHERE id = $1 AND host_id = $2
|
||||
"#,
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "update_window: fetch failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
let existing = existing.ok_or_else(|| {
|
||||
err(
|
||||
StatusCode::NOT_FOUND,
|
||||
"not_found",
|
||||
"Maintenance window not found",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Apply partial updates using existing values as defaults.
|
||||
let new_label = req.label.unwrap_or(existing.label);
|
||||
let new_recurrence = req.recurrence.unwrap_or(existing.recurrence);
|
||||
let new_start_at = req.start_at.unwrap_or(existing.start_at);
|
||||
let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes);
|
||||
let new_rec_day = req.recurrence_day.or(existing.recurrence_day);
|
||||
let new_enabled = req.enabled.unwrap_or(existing.enabled);
|
||||
let new_auto_apply = req.auto_apply.unwrap_or(existing.auto_apply);
|
||||
|
||||
// Validate recurrence_day for the final recurrence type.
|
||||
if new_recurrence == pm_core::models::WindowRecurrence::Weekly {
|
||||
match new_rec_day {
|
||||
Some(d) if (0..=6).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Weekly recurrence requires recurrence_day 0-6",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
if new_recurrence == pm_core::models::WindowRecurrence::Monthly {
|
||||
match new_rec_day {
|
||||
Some(d) if (1..=31).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Monthly recurrence requires recurrence_day 1-31",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let updated: MaintenanceWindow = sqlx::query_as(
|
||||
r#"
|
||||
UPDATE maintenance_windows
|
||||
SET label = $3,
|
||||
recurrence = $4,
|
||||
start_at = $5,
|
||||
duration_minutes = $6,
|
||||
recurrence_day = $7,
|
||||
enabled = $8,
|
||||
auto_apply = $9,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND host_id = $2
|
||||
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.bind(&new_label)
|
||||
.bind(&new_recurrence)
|
||||
.bind(new_start_at)
|
||||
.bind(new_duration)
|
||||
.bind(new_rec_day)
|
||||
.bind(new_enabled)
|
||||
.bind(new_auto_apply)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "update_window: update failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&win_id.to_string()),
|
||||
json!({ "host_id": host_id }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %win_id,
|
||||
%host_id,
|
||||
user = %auth.username,
|
||||
"Maintenance window updated"
|
||||
);
|
||||
|
||||
Ok(Json(json!(updated)))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:host_id/maintenance-windows/:win_id ────────────────
|
||||
|
||||
async fn delete_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
let result = sqlx::query("DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2")
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "delete_window: delete failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(err(
|
||||
StatusCode::NOT_FOUND,
|
||||
"not_found",
|
||||
"Maintenance window not found",
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&win_id.to_string()),
|
||||
json!({ "host_id": host_id }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %win_id,
|
||||
%host_id,
|
||||
user = %auth.username,
|
||||
"Maintenance window deleted"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "message": "Maintenance window deleted" })))
|
||||
}
|
||||
17
crates/pm-web/src/routes/mod.rs
Executable file
17
crates/pm-web/src/routes/mod.rs
Executable file
@ -0,0 +1,17 @@
|
||||
//! Route modules for the pm-web API.
|
||||
pub mod auth;
|
||||
pub mod ca;
|
||||
pub mod discovery;
|
||||
pub mod enrollment;
|
||||
pub mod groups;
|
||||
pub mod health_checks;
|
||||
pub mod hosts;
|
||||
pub mod jobs;
|
||||
pub mod maintenance_windows;
|
||||
pub mod settings;
|
||||
pub mod sso;
|
||||
pub mod status;
|
||||
pub mod users;
|
||||
pub mod ws;
|
||||
|
||||
pub mod reports;
|
||||
163
crates/pm-web/src/routes/reports.rs
Executable file
163
crates/pm-web/src/routes/reports.rs
Executable file
@ -0,0 +1,163 @@
|
||||
//! Report generation endpoints.
|
||||
//!
|
||||
//! GET /api/v1/reports/compliance?format=csv|pdf&from=...&to=...&group_id=...
|
||||
//! GET /api/v1/reports/patch-history?format=csv|pdf&from=...&to=...
|
||||
//! GET /api/v1/reports/vulnerability?format=csv|pdf&from=...&to=...
|
||||
//! GET /api/v1/reports/audit?format=csv|pdf&from=...&to=...
|
||||
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
extract::{Query, State},
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use pm_reports::{ReportParams, ReportType};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ReportQuery {
|
||||
/// "csv" or "pdf" (defaults to "csv")
|
||||
format: Option<String>,
|
||||
from: Option<chrono::DateTime<chrono::Utc>>,
|
||||
to: Option<chrono::DateTime<chrono::Utc>>,
|
||||
group_id: Option<uuid::Uuid>,
|
||||
}
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/compliance", get(compliance_report))
|
||||
.route("/patch-history", get(patch_history_report))
|
||||
.route("/vulnerability", get(vulnerability_report))
|
||||
.route("/audit", get(audit_report))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_report(
|
||||
db: sqlx::PgPool,
|
||||
params: ReportParams,
|
||||
use_pdf: bool,
|
||||
csv_name: &'static str,
|
||||
pdf_name: &'static str,
|
||||
) -> Response {
|
||||
let (ct, disposition, result) = if use_pdf {
|
||||
let disp = format!("attachment; filename=\"{}\"", pdf_name);
|
||||
let data = pm_reports::generate_pdf(&db, ¶ms).await;
|
||||
("application/pdf", disp, data)
|
||||
} else {
|
||||
let disp = format!("attachment; filename=\"{}\"", csv_name);
|
||||
let data = pm_reports::generate_csv(&db, ¶ms).await;
|
||||
("text/csv; charset=utf-8", disp, data)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(bytes) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static(ct));
|
||||
headers.insert(
|
||||
header::CONTENT_DISPOSITION,
|
||||
HeaderValue::from_str(&disposition)
|
||||
.unwrap_or_else(|_| HeaderValue::from_static("attachment")),
|
||||
);
|
||||
(headers, Bytes::from(bytes)).into_response()
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "report generation failed");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Report error: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn compliance_report(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<ReportQuery>,
|
||||
) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::Compliance,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"compliance-report.csv",
|
||||
"compliance-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn patch_history_report(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<ReportQuery>,
|
||||
) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::PatchHistory,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"patch-history-report.csv",
|
||||
"patch-history-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn vulnerability_report(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<ReportQuery>,
|
||||
) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::Vulnerability,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"vulnerability-report.csv",
|
||||
"vulnerability-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn audit_report(State(state): State<AppState>, Query(q): Query<ReportQuery>) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::Audit,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"audit-report.csv",
|
||||
"audit-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
977
crates/pm-web/src/routes/settings.rs
Executable file
977
crates/pm-web/src/routes/settings.rs
Executable file
@ -0,0 +1,977 @@
|
||||
//! Settings management routes.
|
||||
//!
|
||||
//! GET /api/v1/settings — get all settings (admin only)
|
||||
//! PUT /api/v1/settings — update settings (admin only)
|
||||
//! POST /api/v1/settings/sso/discover — discover OIDC endpoints (admin only)
|
||||
//! POST /api/v1/settings/sso/test — test OIDC provider connectivity (admin only)
|
||||
//! POST /api/v1/settings/azure-sso/test — backward-compat alias for SSO test (admin only)
|
||||
//! POST /api/v1/settings/smtp/test — send test email (admin only)
|
||||
//! GET /api/v1/settings/ip-whitelist — get IP whitelist (admin only)
|
||||
//! PUT /api/v1/settings/ip-whitelist — update IP whitelist (admin only)
|
||||
//! POST /api/v1/settings/audit-integrity — verify audit log integrity (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use lettre::{
|
||||
message::{header::ContentType, Mailbox},
|
||||
transport::smtp::authentication::Credentials,
|
||||
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::audit::{log_event, verify_integrity, AuditAction};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ============================================================
|
||||
// Data structures
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SettingsResponse {
|
||||
pub oidc: OidcConfigResponse,
|
||||
pub smtp: SmtpConfig,
|
||||
pub polling: PollingConfig,
|
||||
pub ip_whitelist: Vec<String>,
|
||||
pub web_tls_strategy: String,
|
||||
pub notification: NotificationConfig,
|
||||
pub sso_callback_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OidcConfigResponse {
|
||||
pub enabled: bool,
|
||||
pub provider_type: String, // "keycloak", "azure", "custom"
|
||||
pub display_name: String,
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String, // Always masked in responses
|
||||
pub redirect_uri: String,
|
||||
pub scopes: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SmtpConfig {
|
||||
pub enabled: bool,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub username: String,
|
||||
pub from: String,
|
||||
pub tls_mode: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PollingConfig {
|
||||
pub health_poll_interval_secs: u64,
|
||||
pub patch_poll_interval_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateSettingsRequest {
|
||||
pub oidc: Option<OidcConfigUpdate>,
|
||||
pub smtp: Option<SmtpConfigUpdate>,
|
||||
pub polling: Option<PollingConfigUpdate>,
|
||||
pub ip_whitelist: Option<Vec<String>>,
|
||||
pub web_tls_strategy: Option<String>,
|
||||
pub notification: Option<NotificationConfigUpdate>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct NotificationConfig {
|
||||
pub email_enabled: bool,
|
||||
pub email_from: String,
|
||||
pub recipients: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct NotificationConfigUpdate {
|
||||
pub email_enabled: Option<bool>,
|
||||
pub email_from: Option<String>,
|
||||
pub recipients: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub provider_type: Option<String>,
|
||||
pub display_name: Option<String>,
|
||||
pub discovery_url: Option<String>,
|
||||
pub client_id: Option<String>,
|
||||
pub client_secret: Option<String>,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub scopes: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcDiscoveryRequest {
|
||||
pub discovery_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct OidcDiscoveryResult {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmtpConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub host: Option<String>,
|
||||
pub port: Option<u16>,
|
||||
pub username: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub from: Option<String>,
|
||||
pub tls_mode: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PollingConfigUpdate {
|
||||
pub health_poll_interval_secs: Option<u64>,
|
||||
pub patch_poll_interval_secs: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IpWhitelistUpdate {
|
||||
pub entries: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Router
|
||||
// ============================================================
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(get_settings).put(update_settings))
|
||||
.route("/sso/discover", post(discover_oidc))
|
||||
.route("/sso/test", post(test_oidc))
|
||||
.route("/azure-sso/test", post(test_azure_sso_compat))
|
||||
.route("/smtp/test", post(test_smtp))
|
||||
.route(
|
||||
"/ip-whitelist",
|
||||
get(get_ip_whitelist).put(update_ip_whitelist),
|
||||
)
|
||||
.route("/audit-integrity", post(audit_integrity))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
const MASKED: &str = "********";
|
||||
|
||||
fn write_access_required(auth: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_system_config(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Result<HashMap<String, String>, (StatusCode, Json<Value>)> {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as("SELECT key, value FROM system_config")
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load system_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(rows.into_iter().collect())
|
||||
}
|
||||
|
||||
fn build_settings_response(
|
||||
cfg: &HashMap<String, String>,
|
||||
oidc: OidcConfigResponse,
|
||||
) -> SettingsResponse {
|
||||
let get = |key: &str| -> String { cfg.get(key).cloned().unwrap_or_default() };
|
||||
|
||||
let recipients: Vec<String> =
|
||||
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
|
||||
|
||||
SettingsResponse {
|
||||
oidc,
|
||||
smtp: SmtpConfig {
|
||||
enabled: get("smtp_enabled") == "true",
|
||||
host: get("smtp_host"),
|
||||
port: get("smtp_port").parse().unwrap_or(587),
|
||||
username: get("smtp_username"),
|
||||
from: get("smtp_from"),
|
||||
tls_mode: get("smtp_tls_mode"),
|
||||
},
|
||||
polling: PollingConfig {
|
||||
health_poll_interval_secs: get("health_poll_interval_secs").parse().unwrap_or(300),
|
||||
patch_poll_interval_secs: get("patch_poll_interval_secs").parse().unwrap_or(1800),
|
||||
},
|
||||
ip_whitelist: serde_json::from_str(&get("ip_whitelist")).unwrap_or_default(),
|
||||
web_tls_strategy: get("web_tls_strategy"),
|
||||
notification: NotificationConfig {
|
||||
email_enabled: get("notification_email_enabled") == "true",
|
||||
email_from: get("notification_email_from"),
|
||||
recipients,
|
||||
},
|
||||
sso_callback_url: get("sso_callback_url"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_config_key(
|
||||
pool: &sqlx::PgPool,
|
||||
key: &str,
|
||||
value: &str,
|
||||
) -> Result<(), (StatusCode, Json<Value>)> {
|
||||
sqlx::query("UPDATE system_config SET value = $1, updated_at = NOW() WHERE key = $2")
|
||||
.bind(value)
|
||||
.bind(key)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, key, "Failed to update system_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_oidc_config(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Result<OidcConfigResponse, (StatusCode, Json<Value>)> {
|
||||
let row: Option<(bool, String, String, String, String, String, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(match row {
|
||||
Some((
|
||||
enabled,
|
||||
provider_type,
|
||||
display_name,
|
||||
discovery_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
scopes,
|
||||
)) => OidcConfigResponse {
|
||||
enabled,
|
||||
provider_type,
|
||||
display_name,
|
||||
discovery_url,
|
||||
client_id,
|
||||
client_secret: if client_secret.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
MASKED.to_string()
|
||||
},
|
||||
redirect_uri,
|
||||
scopes,
|
||||
},
|
||||
None => OidcConfigResponse {
|
||||
enabled: false,
|
||||
provider_type: "azure".to_string(),
|
||||
display_name: "Azure AD".to_string(),
|
||||
discovery_url: String::new(),
|
||||
client_id: String::new(),
|
||||
client_secret: String::new(),
|
||||
redirect_uri: String::new(),
|
||||
scopes: "openid profile email".to_string(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/settings
|
||||
// ============================================================
|
||||
|
||||
async fn get_settings(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
let cfg = load_system_config(&state.db).await?;
|
||||
// Inject read-only config values from TOML file (not stored in DB)
|
||||
let mut cfg = cfg;
|
||||
cfg.insert(
|
||||
"sso_callback_url".to_string(),
|
||||
state.config.security.sso_callback_url.clone(),
|
||||
);
|
||||
let oidc = fetch_oidc_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, oidc)))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/settings
|
||||
// ============================================================
|
||||
|
||||
async fn update_settings(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<UpdateSettingsRequest>,
|
||||
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
// Update OIDC config
|
||||
if let Some(oidc) = req.oidc {
|
||||
let update_secret = oidc
|
||||
.client_secret
|
||||
.as_ref()
|
||||
.is_some_and(|s| s != MASKED && !s.is_empty());
|
||||
|
||||
let result = if update_secret {
|
||||
sqlx::query(
|
||||
"UPDATE oidc_config SET \
|
||||
enabled = COALESCE($1, enabled), \
|
||||
provider_type = COALESCE($2, provider_type), \
|
||||
display_name = COALESCE($3, display_name), \
|
||||
discovery_url = COALESCE($4, discovery_url), \
|
||||
client_id = COALESCE($5, client_id), \
|
||||
client_secret = $6, \
|
||||
redirect_uri = COALESCE($7, redirect_uri), \
|
||||
scopes = COALESCE($8, scopes), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = 1",
|
||||
)
|
||||
.bind(oidc.enabled)
|
||||
.bind(&oidc.provider_type)
|
||||
.bind(&oidc.display_name)
|
||||
.bind(&oidc.discovery_url)
|
||||
.bind(&oidc.client_id)
|
||||
.bind(oidc.client_secret.as_deref().unwrap_or(""))
|
||||
.bind(&oidc.redirect_uri)
|
||||
.bind(&oidc.scopes)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE oidc_config SET \
|
||||
enabled = COALESCE($1, enabled), \
|
||||
provider_type = COALESCE($2, provider_type), \
|
||||
display_name = COALESCE($3, display_name), \
|
||||
discovery_url = COALESCE($4, discovery_url), \
|
||||
client_id = COALESCE($5, client_id), \
|
||||
redirect_uri = COALESCE($6, redirect_uri), \
|
||||
scopes = COALESCE($7, scopes), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = 1",
|
||||
)
|
||||
.bind(oidc.enabled)
|
||||
.bind(&oidc.provider_type)
|
||||
.bind(&oidc.display_name)
|
||||
.bind(&oidc.discovery_url)
|
||||
.bind(&oidc.client_id)
|
||||
.bind(&oidc.redirect_uri)
|
||||
.bind(&oidc.scopes)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
result.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": format!("Failed to update OIDC config: {}", e) } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("oidc"),
|
||||
Some("1"),
|
||||
json!({ "section": "oidc" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update SMTP config
|
||||
if let Some(smtp) = &req.smtp {
|
||||
if let Some(v) = smtp.enabled {
|
||||
update_config_key(&state.db, "smtp_enabled", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.host {
|
||||
update_config_key(&state.db, "smtp_host", v).await?;
|
||||
}
|
||||
if let Some(v) = smtp.port {
|
||||
update_config_key(&state.db, "smtp_port", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.username {
|
||||
update_config_key(&state.db, "smtp_username", v).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.password {
|
||||
if v != MASKED {
|
||||
update_config_key(&state.db, "smtp_password", v).await?;
|
||||
}
|
||||
}
|
||||
if let Some(ref v) = smtp.from {
|
||||
update_config_key(&state.db, "smtp_from", v).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.tls_mode {
|
||||
update_config_key(&state.db, "smtp_tls_mode", v).await?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("smtp"),
|
||||
Some("system_config"),
|
||||
json!({ "section": "smtp" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update polling config
|
||||
if let Some(polling) = &req.polling {
|
||||
if let Some(v) = polling.health_poll_interval_secs {
|
||||
update_config_key(&state.db, "health_poll_interval_secs", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(v) = polling.patch_poll_interval_secs {
|
||||
update_config_key(&state.db, "patch_poll_interval_secs", &v.to_string()).await?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("polling"),
|
||||
Some("system_config"),
|
||||
json!({ "section": "polling" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update IP whitelist
|
||||
if let Some(ref entries) = req.ip_whitelist {
|
||||
let json_str = serde_json::to_string(entries).unwrap_or_else(|_| "[]".to_string());
|
||||
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
|
||||
|
||||
// Update in-memory AuthConfig for immediate enforcement
|
||||
state.auth_config.update_ip_whitelist(entries.clone());
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("ip_whitelist"),
|
||||
Some("system_config"),
|
||||
json!({ "entries": entries }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update web TLS strategy
|
||||
if let Some(ref v) = req.web_tls_strategy {
|
||||
update_config_key(&state.db, "web_tls_strategy", v).await?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("web_tls_strategy"),
|
||||
Some("system_config"),
|
||||
json!({ "web_tls_strategy": v }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update notification config
|
||||
if let Some(notif) = &req.notification {
|
||||
if let Some(v) = notif.email_enabled {
|
||||
update_config_key(&state.db, "notification_email_enabled", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(ref v) = notif.email_from {
|
||||
update_config_key(&state.db, "notification_email_from", v).await?;
|
||||
}
|
||||
if let Some(ref v) = notif.recipients {
|
||||
let json_str = serde_json::to_string(v).unwrap_or_else(|_| "[]".to_string());
|
||||
update_config_key(&state.db, "notification_email_recipients", &json_str).await?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("notification"),
|
||||
Some("system_config"),
|
||||
json!({ "section": "notification" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Return updated settings
|
||||
let cfg = load_system_config(&state.db).await?;
|
||||
// Inject read-only config values from TOML file (not stored in DB)
|
||||
let mut cfg = cfg;
|
||||
cfg.insert(
|
||||
"sso_callback_url".to_string(),
|
||||
state.config.security.sso_callback_url.clone(),
|
||||
);
|
||||
let oidc = fetch_oidc_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, oidc)))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/sso/discover
|
||||
// ============================================================
|
||||
|
||||
async fn discover_oidc(
|
||||
State(_state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<OidcDiscoveryRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
if req.discovery_url.is_empty() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": "discovery_url is required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
match client.get(&req.discovery_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: Value = resp.json().await.unwrap_or(json!({}));
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"issuer": body.get("issuer").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"authorization_endpoint": body.get("authorization_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"token_endpoint": body.get("token_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"jwks_uri": body.get("jwks_uri").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"userinfo_endpoint": body.get("userinfo_endpoint").and_then(|v| v.as_str()),
|
||||
})))
|
||||
},
|
||||
Ok(resp) => Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({ "error": { "code": "discovery_failed", "message": format!("Discovery endpoint returned HTTP {}", resp.status()) } }),
|
||||
),
|
||||
)),
|
||||
Err(e) => Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({ "error": { "code": "discovery_failed", "message": format!("Failed to reach discovery endpoint: {}", e) } }),
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/sso/test
|
||||
// ============================================================
|
||||
|
||||
async fn test_oidc(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let row: Option<(bool, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, discovery_url FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (enabled, provider_type, discovery_url) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC is not configured"
|
||||
})));
|
||||
},
|
||||
};
|
||||
|
||||
if !enabled {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC is not enabled"
|
||||
})));
|
||||
}
|
||||
|
||||
if discovery_url.is_empty() {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC discovery URL is not set"
|
||||
})));
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
match client.get(&discovery_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: Value = resp.json().await.unwrap_or(json!({}));
|
||||
let issuer = body.get("issuer").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let provider_label = match provider_type.as_str() {
|
||||
"keycloak" => "Keycloak",
|
||||
"azure" => "Azure AD",
|
||||
_ => "OIDC",
|
||||
};
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": format!("{} provider verified successfully", provider_label),
|
||||
"issuer": issuer,
|
||||
"provider_type": provider_type,
|
||||
})))
|
||||
},
|
||||
Ok(resp) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to reach OIDC provider: HTTP {}", resp.status())
|
||||
}))),
|
||||
Err(e) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to reach OIDC provider: {}", e)
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/azure-sso/test (backward-compatible alias)
|
||||
// ============================================================
|
||||
|
||||
async fn test_azure_sso_compat(
|
||||
state: State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
test_oidc(state, auth).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/smtp/test
|
||||
// ============================================================
|
||||
|
||||
async fn test_smtp(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let cfg = load_system_config(&state.db).await?;
|
||||
|
||||
let smtp_enabled = cfg.get("smtp_enabled").map(|v| v.as_str()) == Some("true");
|
||||
if !smtp_enabled {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "SMTP is not enabled"
|
||||
})));
|
||||
}
|
||||
|
||||
let host = cfg.get("smtp_host").cloned().unwrap_or_default();
|
||||
let port: u16 = cfg
|
||||
.get("smtp_port")
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(587);
|
||||
let username = cfg.get("smtp_username").cloned().unwrap_or_default();
|
||||
let password = cfg.get("smtp_password").cloned().unwrap_or_default();
|
||||
let from_addr = cfg.get("smtp_from").cloned().unwrap_or_default();
|
||||
let tls_mode = cfg
|
||||
.get("smtp_tls_mode")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "starttls".to_string());
|
||||
|
||||
let recipients_str = cfg
|
||||
.get("notification_email_recipients")
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let recipients: Vec<String> = serde_json::from_str(&recipients_str).unwrap_or_default();
|
||||
|
||||
if host.is_empty() || from_addr.is_empty() {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "SMTP host or from address is not configured"
|
||||
})));
|
||||
}
|
||||
|
||||
let result = send_smtp_test(
|
||||
&host,
|
||||
port,
|
||||
&username,
|
||||
&password,
|
||||
&from_addr,
|
||||
&tls_mode,
|
||||
&recipients,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
let recipient_info = if recipients.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" and {} recipient(s)", recipients.len())
|
||||
};
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": format!("Test email sent successfully to from address{}", recipient_info)
|
||||
})))
|
||||
},
|
||||
Err(e) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to send test email: {}", e)
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_smtp_test(
|
||||
host: &str,
|
||||
port: u16,
|
||||
username: &str,
|
||||
password: &str,
|
||||
from_addr: &str,
|
||||
tls_mode: &str,
|
||||
recipients: &[String],
|
||||
) -> Result<(), String> {
|
||||
let from_mailbox: Mailbox = from_addr
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid from address: {}", e))?;
|
||||
|
||||
let mut builder = Message::builder()
|
||||
.from(from_mailbox.clone())
|
||||
.to(from_mailbox);
|
||||
|
||||
for recipient in recipients {
|
||||
if let Ok(addr) = recipient.parse() {
|
||||
builder = builder.bcc(addr);
|
||||
}
|
||||
}
|
||||
|
||||
let body = if recipients.is_empty() {
|
||||
"This is a test email from Linux Patch Manager.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"This is a test email from Linux Patch Manager.\n\nSent to: {}",
|
||||
recipients.join(", ")
|
||||
)
|
||||
};
|
||||
|
||||
let email = builder
|
||||
.subject("Linux Patch Manager — SMTP Test")
|
||||
.header(ContentType::TEXT_PLAIN)
|
||||
.body(body)
|
||||
.map_err(|e| format!("Failed to build email: {}", e))?;
|
||||
|
||||
let result = match tls_mode {
|
||||
"tls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(host)
|
||||
.map_err(|e| format!("TLS relay error: {}", e))?;
|
||||
builder = builder.port(port);
|
||||
if !username.is_empty() {
|
||||
builder = builder
|
||||
.credentials(Credentials::new(username.to_string(), password.to_string()));
|
||||
}
|
||||
let transport = builder.build();
|
||||
transport.send(email).await
|
||||
},
|
||||
"starttls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(host)
|
||||
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
|
||||
builder = builder.port(port);
|
||||
if !username.is_empty() {
|
||||
builder = builder
|
||||
.credentials(Credentials::new(username.to_string(), password.to_string()));
|
||||
}
|
||||
let transport = builder.build();
|
||||
transport.send(email).await
|
||||
},
|
||||
_ => {
|
||||
// "none" — plaintext / no TLS
|
||||
let mut builder =
|
||||
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(host).port(port);
|
||||
if !username.is_empty() {
|
||||
builder = builder
|
||||
.credentials(Credentials::new(username.to_string(), password.to_string()));
|
||||
}
|
||||
let transport = builder.build();
|
||||
transport.send(email).await
|
||||
},
|
||||
};
|
||||
|
||||
result
|
||||
.map(|_| ())
|
||||
.map_err(|e| format!("SMTP send error: {}", e))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/settings/ip-whitelist
|
||||
// ============================================================
|
||||
|
||||
async fn get_ip_whitelist(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let value: Option<String> = sqlx::query_scalar(
|
||||
"SELECT value FROM system_config WHERE key = 'ip_whitelist'",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load ip_whitelist");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let entries: Vec<String> = serde_json::from_str(&value.unwrap_or_default()).unwrap_or_default();
|
||||
Ok(Json(json!({ "entries": entries })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/settings/ip-whitelist
|
||||
// ============================================================
|
||||
|
||||
async fn update_ip_whitelist(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<IpWhitelistUpdate>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
// Validate each entry
|
||||
for entry in &req.entries {
|
||||
if entry.parse::<ipnet::IpNet>().is_err() && entry.parse::<std::net::IpAddr>().is_err() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": format!("Invalid CIDR or IP: {}", entry) } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let json_str = serde_json::to_string(&req.entries).unwrap_or_else(|_| "[]".to_string());
|
||||
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
|
||||
|
||||
// Update in-memory AuthConfig for immediate enforcement
|
||||
state.auth_config.update_ip_whitelist(req.entries.clone());
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("ip_whitelist"),
|
||||
Some("system_config"),
|
||||
json!({ "entries": req.entries }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
Ok(Json(json!({ "entries": req.entries })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/audit-integrity
|
||||
// ============================================================
|
||||
|
||||
/// Verify audit log hash chain integrity.
|
||||
/// Returns whether the chain is intact, rows checked, and any errors.
|
||||
async fn audit_integrity(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let result = verify_integrity(&state.db).await;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::AuditIntegrityVerified,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("audit_log"),
|
||||
None,
|
||||
json!({
|
||||
"intact": result.intact,
|
||||
"rows_checked": result.rows_checked,
|
||||
"error_count": result.errors.len(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({
|
||||
"intact": result.intact,
|
||||
"rows_checked": result.rows_checked,
|
||||
"errors": result.errors.iter().map(|e| json!({
|
||||
"row_id": e.row_id,
|
||||
"expected_hash": e.expected_hash,
|
||||
"actual_hash": e.actual_hash,
|
||||
})).collect::<Vec<_>>(),
|
||||
})))
|
||||
}
|
||||
838
crates/pm-web/src/routes/sso.rs
Executable file
838
crates/pm-web/src/routes/sso.rs
Executable file
@ -0,0 +1,838 @@
|
||||
//! Generic OIDC SSO routes (Keycloak, Azure AD, Custom).
|
||||
//!
|
||||
//! Public routes (no auth required):
|
||||
//! GET /api/v1/auth/sso/login — redirect to OIDC provider authorization URL
|
||||
//! GET /api/v1/auth/sso/callback — handle OIDC provider callback, redirect to frontend SPA
|
||||
//!
|
||||
//! Backward-compatible aliases:
|
||||
//! GET /api/v1/auth/azure/login → redirects to generic SSO login
|
||||
//! GET /api/v1/auth/azure/callback → redirects to generic SSO callback
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json, Redirect},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
|
||||
use pm_auth::{jwt::issue_access_token, refresh};
|
||||
use pm_core::audit::{log_event, AuditAction};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ============================================================
|
||||
// Data structures
|
||||
// ============================================================
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SsoSession {
|
||||
pub code_verifier: String,
|
||||
pub created_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
#[allow(dead_code)]
|
||||
access_token: Option<String>,
|
||||
id_token: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
expires_in: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IdTokenClaims {
|
||||
email: Option<String>,
|
||||
name: Option<String>,
|
||||
sub: Option<String>,
|
||||
oid: Option<String>,
|
||||
preferred_username: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct DbUserForSso {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
display_name: String,
|
||||
role: String,
|
||||
is_active: bool,
|
||||
mfa_enabled: bool,
|
||||
}
|
||||
|
||||
/// OIDC provider configuration from database.
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
pub struct OidcConfig {
|
||||
pub enabled: bool,
|
||||
pub provider_type: String,
|
||||
pub display_name: String,
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub redirect_uri: String,
|
||||
pub scopes: String,
|
||||
}
|
||||
|
||||
/// Cached OIDC discovery document.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OidcDiscovery {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>,
|
||||
pub fetched_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Cache for OIDC discovery documents and JWKS with TTL-based refresh.
|
||||
#[derive(Default)]
|
||||
pub struct OidcCache {
|
||||
pub discovery: Option<OidcDiscovery>,
|
||||
pub jwks: Option<serde_json::Value>,
|
||||
pub jwks_fetched_at: Option<chrono::DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// JWKS cache TTL in seconds (1 hour).
|
||||
const JWKS_CACHE_TTL_SECS: i64 = 3600;
|
||||
/// Discovery cache TTL in seconds (1 hour).
|
||||
const DISCOVERY_CACHE_TTL_SECS: i64 = 3600;
|
||||
|
||||
// ============================================================
|
||||
// Router
|
||||
// ============================================================
|
||||
|
||||
pub fn public_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", get(sso_login))
|
||||
.route("/callback", get(sso_callback))
|
||||
.route("/config", get(sso_config))
|
||||
}
|
||||
|
||||
/// Backward-compatible Azure SSO routes — redirect to generic SSO endpoints.
|
||||
pub fn azure_compat_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", get(azure_login_redirect))
|
||||
.route("/callback", get(azure_callback_redirect))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/sso/config
|
||||
// ============================================================
|
||||
|
||||
/// Public endpoint returning minimal SSO configuration for the login page.
|
||||
/// Returns only: enabled, display_name, auth_url — no secrets exposed.
|
||||
async fn sso_config(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
// If we can't load config, SSO is effectively disabled
|
||||
return Ok(Json(json!({
|
||||
"enabled": false,
|
||||
"display_name": "SSO",
|
||||
"auth_url": ""
|
||||
})));
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Json(json!({
|
||||
"enabled": config.enabled,
|
||||
"display_name": if config.display_name.is_empty() { "SSO".to_string() } else { config.display_name },
|
||||
"auth_url": "/api/v1/auth/sso/login"
|
||||
})))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/sso/login
|
||||
// ============================================================
|
||||
|
||||
async fn sso_login(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
let config = load_oidc_config(&state.db).await?;
|
||||
|
||||
if !config.enabled {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "SSO is not enabled" } })),
|
||||
));
|
||||
}
|
||||
|
||||
if config.discovery_url.is_empty() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "SSO discovery URL is not configured" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch OIDC discovery document (with caching)
|
||||
let discovery = match fetch_discovery(&state).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
json!({ "error": { "code": "internal_error", "message": format!("Failed to fetch OIDC discovery: {}", e) } }),
|
||||
),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Generate PKCE code_verifier (32 random bytes → base64url)
|
||||
let mut verifier_bytes = [0u8; 32];
|
||||
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut verifier_bytes);
|
||||
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
|
||||
|
||||
// code_challenge = BASE64URL(SHA256(code_verifier))
|
||||
let challenge_digest = Sha256::digest(code_verifier.as_bytes());
|
||||
let code_challenge = URL_SAFE_NO_PAD.encode(challenge_digest);
|
||||
|
||||
// Generate state token
|
||||
let state_token = Uuid::new_v4().to_string();
|
||||
|
||||
// Store (state_token, code_verifier) in sso_sessions DashMap
|
||||
state.sso_sessions.insert(
|
||||
state_token.clone(),
|
||||
SsoSession {
|
||||
code_verifier,
|
||||
created_at: Utc::now(),
|
||||
},
|
||||
);
|
||||
|
||||
// Build authorization URL from discovery
|
||||
let encoded_scopes = urlencoding::encode(&config.scopes);
|
||||
let auth_url = format!(
|
||||
"{}?client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||
discovery.authorization_endpoint,
|
||||
urlencoding::encode(&config.client_id),
|
||||
urlencoding::encode(&config.redirect_uri),
|
||||
encoded_scopes,
|
||||
code_challenge,
|
||||
state_token
|
||||
);
|
||||
|
||||
Ok(Redirect::to(&auth_url))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/sso/callback
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: Option<String>,
|
||||
state: Option<String>,
|
||||
error: Option<String>,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
async fn sso_callback(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<Redirect, Redirect> {
|
||||
let callback_url = &state.config.security.sso_callback_url;
|
||||
|
||||
let error_redirect = |code: &str, message: &str| -> Redirect {
|
||||
let url = format!(
|
||||
"{}?error={}&error_description={}",
|
||||
callback_url,
|
||||
urlencoding::encode(code),
|
||||
urlencoding::encode(message)
|
||||
);
|
||||
Redirect::to(&url)
|
||||
};
|
||||
|
||||
if let Some(error) = params.error {
|
||||
let desc = params.error_description.unwrap_or_default();
|
||||
let message = format!("OIDC provider error: {} - {}", error, desc);
|
||||
return Err(error_redirect("sso_error", &message));
|
||||
}
|
||||
|
||||
let code = match params.code {
|
||||
Some(c) => c,
|
||||
None => return Err(error_redirect("bad_request", "Missing authorization code")),
|
||||
};
|
||||
|
||||
let state_token = match params.state {
|
||||
Some(s) => s,
|
||||
None => return Err(error_redirect("bad_request", "Missing state parameter")),
|
||||
};
|
||||
|
||||
let sso_session = match state.sso_sessions.remove(&state_token).map(|(_, v)| v) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Err(error_redirect(
|
||||
"bad_request",
|
||||
"Invalid or expired state token",
|
||||
))
|
||||
},
|
||||
};
|
||||
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to load OIDC config",
|
||||
))
|
||||
},
|
||||
};
|
||||
|
||||
let discovery = match fetch_discovery(&state).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to fetch OIDC discovery");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to fetch OIDC discovery",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Exchange code for tokens
|
||||
let client = match reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
return Err(error_redirect("internal_error", "HTTP client error"));
|
||||
},
|
||||
};
|
||||
|
||||
let mut params_vec: Vec<(&str, String)> = vec![
|
||||
("grant_type", "authorization_code".to_string()),
|
||||
("code", code.clone()),
|
||||
("redirect_uri", config.redirect_uri.clone()),
|
||||
("client_id", config.client_id.clone()),
|
||||
("code_verifier", sso_session.code_verifier.clone()),
|
||||
];
|
||||
|
||||
// For confidential clients (Azure AD), include client_secret
|
||||
if !config.client_secret.is_empty() {
|
||||
params_vec.push(("client_secret", config.client_secret.clone()));
|
||||
}
|
||||
|
||||
let token_resp = match client
|
||||
.post(&discovery.token_endpoint)
|
||||
.form(¶ms_vec)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Token exchange request failed");
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
&format!("Token exchange failed: {}", e),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
if !token_resp.status().is_success() {
|
||||
let status = token_resp.status();
|
||||
let body = token_resp.text().await.unwrap_or_default();
|
||||
tracing::error!(status = %status, body = %body, "Token exchange failed");
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
&format!("Token exchange failed: HTTP {}", status),
|
||||
));
|
||||
}
|
||||
|
||||
let token_data: TokenResponse = match token_resp.json().await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to parse token response");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to parse token response",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
let id_token = match token_data.id_token {
|
||||
Some(t) => t,
|
||||
None => return Err(error_redirect("sso_error", "No id_token in response")),
|
||||
};
|
||||
|
||||
let claims = match verify_id_token(&id_token, &config, &discovery, &state.oidc_cache).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to verify id_token");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to verify id_token",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
let email = claims.email.unwrap_or_default();
|
||||
let name = claims.name.unwrap_or_default();
|
||||
let oidc_sub = claims.sub.unwrap_or_default();
|
||||
let azure_oid = claims.oid.unwrap_or_default();
|
||||
let preferred_username = claims.preferred_username.unwrap_or_else(|| email.clone());
|
||||
|
||||
let provider_subject = if !oidc_sub.is_empty() {
|
||||
oidc_sub.clone()
|
||||
} else if !azure_oid.is_empty() {
|
||||
azure_oid.clone()
|
||||
} else {
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
"Missing subject identifier in id_token",
|
||||
));
|
||||
};
|
||||
|
||||
if email.is_empty() {
|
||||
return Err(error_redirect("sso_error", "Missing email in id_token"));
|
||||
}
|
||||
|
||||
let auth_provider = match config.provider_type.as_str() {
|
||||
"keycloak" => "keycloak",
|
||||
"azure" => "azure_sso",
|
||||
_ => "oidc",
|
||||
};
|
||||
|
||||
// First try exact match: email AND auth_provider
|
||||
let user_opt: Option<DbUserForSso> = match sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
|
||||
FROM users WHERE email = $1 AND auth_provider = $2::auth_provider"#,
|
||||
)
|
||||
.bind(&email)
|
||||
.bind(auth_provider)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(o) => o,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to look up SSO user");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
},
|
||||
};
|
||||
|
||||
let user = match user_opt {
|
||||
Some(u) if !u.is_active => {
|
||||
return Err(error_redirect("account_disabled", "Account is disabled"));
|
||||
},
|
||||
Some(u) => u,
|
||||
None => {
|
||||
// Try to find existing user by email alone (may have different auth_provider)
|
||||
let existing_user: Option<DbUserForSso> = match sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
|
||||
FROM users WHERE email = $1"#,
|
||||
)
|
||||
.bind(&email)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(o) => o,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to look up existing user by email");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
},
|
||||
};
|
||||
|
||||
match existing_user {
|
||||
Some(existing) if !existing.is_active => {
|
||||
return Err(error_redirect("account_disabled", "Account is disabled"));
|
||||
},
|
||||
Some(existing) => {
|
||||
// Link existing local user to SSO provider
|
||||
tracing::info!(user_id = %existing.id, "Linking existing user to SSO provider");
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE users SET auth_provider = $1::auth_provider, azure_oid = COALESCE(azure_oid, $2), oidc_sub = COALESCE(oidc_sub, $3) WHERE id = $4",
|
||||
)
|
||||
.bind(auth_provider)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.bind(existing.id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, "Failed to link user to SSO provider");
|
||||
return Err(error_redirect("internal_error", "Failed to link SSO account"));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
None,
|
||||
Some(auth_provider),
|
||||
Some("user"),
|
||||
Some(&existing.id.to_string()),
|
||||
json!({ "action": "sso_link", "auth_provider": auth_provider, "email": email }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
DbUserForSso {
|
||||
id: existing.id,
|
||||
username: existing.username.clone(),
|
||||
display_name: if name.is_empty() {
|
||||
existing.display_name.clone()
|
||||
} else {
|
||||
name
|
||||
},
|
||||
role: existing.role.clone(),
|
||||
is_active: existing.is_active,
|
||||
mfa_enabled: existing.mfa_enabled,
|
||||
}
|
||||
},
|
||||
None => {
|
||||
// No existing user - create new one
|
||||
let id: Uuid = match sqlx::query_scalar(
|
||||
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid, oidc_sub)
|
||||
VALUES ($1, $2, $3, 'reporter'::user_role, $4::auth_provider, $5, $6)
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(&preferred_username)
|
||||
.bind(&name)
|
||||
.bind(&email)
|
||||
.bind(auth_provider)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to create SSO user");
|
||||
return Err(error_redirect("internal_error", "Failed to create user"));
|
||||
},
|
||||
};
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
None,
|
||||
Some(auth_provider),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "auth_provider": auth_provider, "email": email }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
DbUserForSso {
|
||||
id,
|
||||
username: preferred_username,
|
||||
display_name: name,
|
||||
role: "reporter".to_string(),
|
||||
is_active: true,
|
||||
mfa_enabled: false,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Update last_login_at and provider subject IDs
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1), oidc_sub = COALESCE(oidc_sub, $2) WHERE id = $3",
|
||||
)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.bind(user.id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, "Failed to update last_login_at");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
}
|
||||
|
||||
let access_ttl = state.config.security.jwt_access_ttl_secs as i64;
|
||||
let access_token = match issue_access_token(
|
||||
user.id,
|
||||
&user.username,
|
||||
&user.role,
|
||||
access_ttl,
|
||||
&state.signing_key_pem,
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to issue access token");
|
||||
return Err(error_redirect("internal_error", "Token issuance failed"));
|
||||
},
|
||||
};
|
||||
|
||||
let raw_refresh = match refresh::issue(&state.db, user.id, None, None).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to issue refresh token");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Refresh token issuance failed",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserLogin,
|
||||
Some(user.id),
|
||||
Some(&user.username),
|
||||
None,
|
||||
None,
|
||||
json!({ "auth_provider": auth_provider }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let user_json = json!({
|
||||
"id": user.id.to_string(),
|
||||
"username": user.username,
|
||||
"display_name": user.display_name,
|
||||
"role": user.role,
|
||||
"auth_provider": auth_provider,
|
||||
"mfa_enabled": user.mfa_enabled,
|
||||
});
|
||||
|
||||
let redirect_url = format!(
|
||||
"{}?access_token={}&refresh_token={}&token_type=Bearer&expires_in={}&user={}",
|
||||
callback_url,
|
||||
urlencoding::encode(&access_token),
|
||||
urlencoding::encode(&raw_refresh.0),
|
||||
access_ttl,
|
||||
urlencoding::encode(&user_json.to_string()),
|
||||
);
|
||||
|
||||
Ok(Redirect::to(&redirect_url))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Backward-compatible Azure SSO redirect handlers
|
||||
// ============================================================
|
||||
|
||||
async fn azure_login_redirect(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
sso_login(State(state)).await
|
||||
}
|
||||
|
||||
async fn azure_callback_redirect(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<Redirect, Redirect> {
|
||||
sso_callback(State(state), axum::extract::Query(params)).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Database helpers
|
||||
// ============================================================
|
||||
|
||||
async fn load_oidc_config(pool: &sqlx::PgPool) -> Result<OidcConfig, (StatusCode, Json<Value>)> {
|
||||
let row: Option<OidcConfig> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(row.unwrap_or(OidcConfig {
|
||||
enabled: false,
|
||||
provider_type: "azure".to_string(),
|
||||
display_name: "Azure AD".to_string(),
|
||||
discovery_url: String::new(),
|
||||
client_id: String::new(),
|
||||
client_secret: String::new(),
|
||||
redirect_uri: String::new(),
|
||||
scopes: "openid profile email".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OIDC Discovery & JWKS
|
||||
// ============================================================
|
||||
|
||||
async fn fetch_discovery(state: &AppState) -> Result<OidcDiscovery, String> {
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err("Failed to load OIDC config".to_string());
|
||||
},
|
||||
};
|
||||
let discovery_url = config.discovery_url;
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let cache = state.oidc_cache.lock().await;
|
||||
if let Some(ref disc) = cache.discovery {
|
||||
let elapsed = Utc::now().signed_duration_since(disc.fetched_at);
|
||||
if elapsed.num_seconds() < DISCOVERY_CACHE_TTL_SECS {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Discovery fetch failed: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"Discovery fetch failed: HTTP {} — {}",
|
||||
status, body
|
||||
));
|
||||
}
|
||||
|
||||
let doc: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse discovery document: {}", e))?;
|
||||
|
||||
let discovery = OidcDiscovery {
|
||||
issuer: doc
|
||||
.get("issuer")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
authorization_endpoint: doc
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
token_endpoint: doc
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
jwks_uri: doc
|
||||
.get("jwks_uri")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
userinfo_endpoint: doc
|
||||
.get("userinfo_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string()),
|
||||
fetched_at: Utc::now(),
|
||||
};
|
||||
|
||||
{
|
||||
let mut cache = state.oidc_cache.lock().await;
|
||||
cache.discovery = Some(discovery.clone());
|
||||
}
|
||||
|
||||
Ok(discovery)
|
||||
}
|
||||
|
||||
async fn verify_id_token(
|
||||
token: &str,
|
||||
config: &OidcConfig,
|
||||
discovery: &OidcDiscovery,
|
||||
oidc_cache: &Arc<Mutex<OidcCache>>,
|
||||
) -> Result<IdTokenClaims, String> {
|
||||
let header = decode_header(token).map_err(|e| format!("Failed to decode JWT header: {}", e))?;
|
||||
let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
|
||||
|
||||
let jwks = {
|
||||
let cache = oidc_cache.lock().await;
|
||||
let needs_fetch = match (&cache.jwks, &cache.jwks_fetched_at) {
|
||||
(None, _) => true,
|
||||
(Some(_), None) => true,
|
||||
(Some(_), Some(fetched)) => {
|
||||
let elapsed = Utc::now().signed_duration_since(*fetched);
|
||||
elapsed.num_seconds() > JWKS_CACHE_TTL_SECS
|
||||
},
|
||||
};
|
||||
|
||||
if needs_fetch {
|
||||
drop(cache);
|
||||
let jwks_value = fetch_jwks(&discovery.jwks_uri).await?;
|
||||
let mut cache = oidc_cache.lock().await;
|
||||
cache.jwks = Some(jwks_value);
|
||||
cache.jwks_fetched_at = Some(Utc::now());
|
||||
cache.jwks.clone().unwrap()
|
||||
} else {
|
||||
cache.jwks.clone().unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
let keys_array = jwks
|
||||
.get("keys")
|
||||
.ok_or("JWKS response missing 'keys' array")?
|
||||
.as_array()
|
||||
.ok_or("JWKS 'keys' is not an array")?;
|
||||
|
||||
let jwk = keys_array
|
||||
.iter()
|
||||
.find(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()))
|
||||
.ok_or_else(|| format!("No matching JWK found for kid: {}", kid))?;
|
||||
|
||||
let n = jwk
|
||||
.get("n")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or("JWK missing 'n' (modulus) field")?;
|
||||
let e = jwk
|
||||
.get("e")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or("JWK missing 'e' (exponent) field")?;
|
||||
|
||||
let decoding_key = DecodingKey::from_rsa_components(n, e)
|
||||
.map_err(|e| format!("Failed to construct RSA decoding key: {}", e))?;
|
||||
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.iss = Some(HashSet::from([discovery.issuer.clone()]));
|
||||
validation.aud = Some(HashSet::from([config.client_id.clone()]));
|
||||
validation.leeway = 60;
|
||||
|
||||
let token_data = decode::<IdTokenClaims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("JWT signature verification failed: {}", e))?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
async fn fetch_jwks(jwks_uri: &str) -> Result<serde_json::Value, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build HTTP client for JWKS fetch: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(jwks_uri)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("JWKS fetch request failed: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("JWKS fetch failed: HTTP {} — {}", status, body));
|
||||
}
|
||||
|
||||
resp.json::<serde_json::Value>()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse JWKS response: {}", e))
|
||||
}
|
||||
145
crates/pm-web/src/routes/status.rs
Executable file
145
crates/pm-web/src/routes/status.rs
Executable file
@ -0,0 +1,145 @@
|
||||
//! Fleet status routes.
|
||||
//!
|
||||
//! GET /api/v1/status/fleet — aggregate health and patch summary across all hosts.
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::Json, routing::get, Router};
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new().route("/fleet", get(fleet_status))
|
||||
}
|
||||
|
||||
// ── Response type ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FleetStatus {
|
||||
pub total_hosts: i64,
|
||||
pub healthy: i64,
|
||||
pub degraded: i64,
|
||||
pub unreachable: i64,
|
||||
pub pending: i64,
|
||||
pub total_pending_patches: i64,
|
||||
pub hosts_requiring_reboot: i64,
|
||||
pub compliance_pct: f64,
|
||||
}
|
||||
|
||||
// ── GET /api/v1/status/fleet ──────────────────────────────────────────────────
|
||||
|
||||
pub async fn fleet_status(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<FleetStatus>, (StatusCode, Json<Value>)> {
|
||||
// ── 1. Host health aggregates ─────────────────────────────────────────
|
||||
let health_row: (i64, i64, i64, i64, i64) = sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) AS total_hosts,
|
||||
COUNT(*) FILTER (WHERE health_status = 'healthy') AS healthy,
|
||||
COUNT(*) FILTER (WHERE health_status = 'degraded') AS degraded,
|
||||
COUNT(*) FILTER (WHERE health_status = 'unreachable') AS unreachable,
|
||||
COUNT(*) FILTER (WHERE health_status = 'pending') AS pending
|
||||
FROM hosts
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query host health aggregates");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (total_hosts, healthy, degraded, unreachable, pending) = health_row;
|
||||
|
||||
// ── 2. Total pending patches across fleet (latest row per host) ───────
|
||||
let total_pending_patches: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COALESCE(SUM(patch_count), 0)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) patch_count
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query total pending patches");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// ── 3. Hosts requiring a reboot (latest patch row per host) ───────────
|
||||
let hosts_requiring_reboot: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) available_patches
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
WHERE available_patches @> '[{"requires_reboot": true}]'
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query reboot-required hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// ── 4. Compliance: hosts with zero pending patches / total hosts ───────
|
||||
// Hosts that have been polled and have patch_count == 0 are considered
|
||||
// compliant. Hosts with no patch data at all are excluded from the
|
||||
// compliance calculation.
|
||||
let compliant_hosts: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) patch_count
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
WHERE patch_count = 0
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query compliant hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let compliance_pct = if total_hosts == 0 {
|
||||
100.0_f64
|
||||
} else {
|
||||
(compliant_hosts as f64 / total_hosts as f64) * 100.0
|
||||
};
|
||||
|
||||
// Round to one decimal place.
|
||||
let compliance_pct = (compliance_pct * 10.0).round() / 10.0;
|
||||
|
||||
Ok(Json(FleetStatus {
|
||||
total_hosts,
|
||||
healthy,
|
||||
degraded,
|
||||
unreachable,
|
||||
pending,
|
||||
total_pending_patches,
|
||||
hosts_requiring_reboot,
|
||||
compliance_pct,
|
||||
}))
|
||||
}
|
||||
571
crates/pm-web/src/routes/users.rs
Executable file
571
crates/pm-web/src/routes/users.rs
Executable file
@ -0,0 +1,571 @@
|
||||
//! User management routes.
|
||||
//!
|
||||
//! GET /api/v1/users — list users (admin only)
|
||||
//! POST /api/v1/users — create user (admin only)
|
||||
//! GET /api/v1/users/:id — get user detail
|
||||
//! PUT /api/v1/users/:id — update user
|
||||
//! DELETE /api/v1/users/:id — delete user (admin only)
|
||||
//! GET /api/v1/users/me — current user profile
|
||||
//! POST /api/v1/users/:id/revoke — revoke all sessions (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{delete, get, post, put},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::validate_password_strength;
|
||||
use pm_auth::{hash_password, rbac::AuthUser, session::force_logout, verify_password};
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{
|
||||
AdminResetPasswordRequest, ChangePasswordRequest, CreateUserRequest, UpdateUserRequest,
|
||||
User,
|
||||
},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_users).post(create_user))
|
||||
.route("/me", get(get_current_user))
|
||||
.route("/me/password", put(change_own_password))
|
||||
.route("/{id}", get(get_user).put(update_user).delete(delete_user))
|
||||
.route("/{id}/password", put(admin_reset_password))
|
||||
.route("/{id}/mfa", delete(admin_disable_mfa))
|
||||
.route("/{id}/revoke", post(revoke_user_sessions))
|
||||
}
|
||||
|
||||
async fn list_users(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Vec<User>>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query_as::<_, User>(
|
||||
r#"SELECT id, username, display_name, email, role, auth_provider,
|
||||
mfa_enabled, is_active, force_password_reset, last_login_at,
|
||||
created_at, updated_at
|
||||
FROM users ORDER BY username"#,
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateUserRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate password strength
|
||||
if let Err(msg) = validate_password_strength(&req.password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
let hash = hash_password(&req.password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let role = match req.role.to_lowercase().as_str() {
|
||||
"admin" => "admin",
|
||||
"reporter" => "reporter",
|
||||
_ => "operator",
|
||||
};
|
||||
|
||||
let id: Uuid = sqlx::query_scalar(
|
||||
r#"INSERT INTO users (username, display_name, email, role, auth_provider, password_hash)
|
||||
VALUES ($1, $2, $3, $4::user_role, 'local', $5)
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(&req.username)
|
||||
.bind(req.display_name.as_deref().unwrap_or(&req.username))
|
||||
.bind(&req.email)
|
||||
.bind(role)
|
||||
.bind(&hash)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"Username or email already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "username": req.username }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "id": id, "message": "User created" })))
|
||||
}
|
||||
|
||||
async fn get_current_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
|
||||
fetch_user(&state.db, auth.user_id).await
|
||||
}
|
||||
|
||||
async fn get_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
|
||||
// Users can see themselves; admin can see anyone
|
||||
if !auth.role.is_admin() && auth.user_id != id {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
fetch_user(&state.db, id).await
|
||||
}
|
||||
|
||||
async fn fetch_user(
|
||||
pool: &sqlx::PgPool,
|
||||
id: Uuid,
|
||||
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
|
||||
let user: Option<User> = sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, email, role, auth_provider,
|
||||
mfa_enabled, is_active, force_password_reset, last_login_at,
|
||||
created_at, updated_at
|
||||
FROM users WHERE id = $1"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
user.map(Json).ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<UpdateUserRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() && auth.user_id != id {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
// Only admins can change role or active status
|
||||
if (req.role.is_some() || req.is_active.is_some() || req.force_password_reset.is_some())
|
||||
&& !auth.role.is_admin()
|
||||
{
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "Admin role required to change role, status, or force_password_reset" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let role_str = req
|
||||
.role
|
||||
.as_deref()
|
||||
.map(|r| match r.to_lowercase().as_str() {
|
||||
"admin" => "admin",
|
||||
"reporter" => "reporter",
|
||||
_ => "operator",
|
||||
});
|
||||
|
||||
let rows = sqlx::query(
|
||||
r#"UPDATE users SET
|
||||
display_name = COALESCE($1, display_name),
|
||||
email = COALESCE($2, email),
|
||||
role = COALESCE($3::user_role, role),
|
||||
is_active = COALESCE($4, is_active),
|
||||
force_password_reset = COALESCE($5, force_password_reset),
|
||||
updated_at = NOW()
|
||||
WHERE id = $6"#,
|
||||
)
|
||||
.bind(req.display_name.as_deref())
|
||||
.bind(req.email.as_deref())
|
||||
.bind(role_str)
|
||||
.bind(req.is_active)
|
||||
.bind(req.force_password_reset)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User updated" })))
|
||||
}
|
||||
|
||||
async fn delete_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
if auth.user_id == id {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": "Cannot delete your own account" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query("DELETE FROM users WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User deleted" })))
|
||||
}
|
||||
|
||||
async fn revoke_user_sessions(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let count = force_logout(&state.db, id).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(
|
||||
json!({ "message": "Sessions revoked", "count": count }),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/users/me/password — change own password
|
||||
// ============================================================
|
||||
|
||||
async fn change_own_password(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<ChangePasswordRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Fetch current password hash
|
||||
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
|
||||
.bind(auth.user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to fetch password hash");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?
|
||||
.flatten();
|
||||
|
||||
let hash_str = hash.unwrap_or_default();
|
||||
let valid = verify_password(&req.current_password, &hash_str).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate new password strength
|
||||
if let Err(msg) = validate_password_strength(&req.new_password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
let new_hash = hash_password(&req.new_password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, updated_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(&new_hash)
|
||||
.bind(auth.user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update password");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&auth.user_id.to_string()),
|
||||
json!({ "action": "password_change" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Password changed successfully" })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/users/:id/password — admin reset password
|
||||
// ============================================================
|
||||
|
||||
async fn admin_reset_password(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<AdminResetPasswordRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Verify target user exists
|
||||
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !exists {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate new password strength
|
||||
if let Err(msg) = validate_password_strength(&req.new_password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
let new_hash = hash_password(&req.new_password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users SET password_hash = $1, force_password_reset = $2, updated_at = NOW() WHERE id = $3",
|
||||
)
|
||||
.bind(&new_hash)
|
||||
.bind(req.force_password_reset)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to reset password");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to reset password" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "action": "admin_password_reset", "force_password_reset": req.force_password_reset }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Password reset successfully" })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// DELETE /api/v1/users/:id/mfa — admin disable MFA
|
||||
// ============================================================
|
||||
|
||||
async fn admin_disable_mfa(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE, updated_at = NOW() WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to disable MFA");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "action": "admin_mfa_disabled" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "MFA disabled successfully" })))
|
||||
}
|
||||
205
crates/pm-web/src/routes/ws.rs
Executable file
205
crates/pm-web/src/routes/ws.rs
Executable file
@ -0,0 +1,205 @@
|
||||
//! WebSocket relay routes — M7
|
||||
//!
|
||||
//! POST /api/v1/ws/ticket — create a single-use WS auth ticket (JWT-protected)
|
||||
//! GET /api/v1/ws/jobs — browser WebSocket endpoint (ticket-authenticated)
|
||||
|
||||
use axum::{
|
||||
extract::ws::{Message, WebSocket},
|
||||
extract::{Query, State, WebSocketUpgrade},
|
||||
http::StatusCode,
|
||||
response::{Json, Response},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use chrono::{Duration, Utc};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::postgres::PgListener;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── WsTicket ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Single-use WebSocket authentication ticket stored in-memory.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WsTicket {
|
||||
pub user_id: Uuid,
|
||||
pub role: String,
|
||||
pub expires_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Router for ticket-issuance endpoint (JWT-protected, merged into protected_api).
|
||||
pub fn ticket_router() -> Router<AppState> {
|
||||
Router::new().route("/ws/ticket", post(create_ticket_handler))
|
||||
}
|
||||
|
||||
/// Router for the WebSocket endpoint (ticket-authenticated, NO JWT middleware).
|
||||
pub fn ws_router() -> Router<AppState> {
|
||||
Router::new().route("/api/v1/ws/jobs", get(ws_handler))
|
||||
}
|
||||
|
||||
// ── Error helper ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── POST /api/v1/ws/ticket ────────────────────────────────────────────────────
|
||||
|
||||
/// Issue a single-use WebSocket authentication ticket (60 s expiry).
|
||||
pub async fn create_ticket_handler(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let ticket_id = Ulid::new().to_string();
|
||||
let expires_at = Utc::now() + Duration::seconds(60);
|
||||
|
||||
let ticket = WsTicket {
|
||||
user_id: auth.user_id,
|
||||
role: auth.role.as_str().to_string(),
|
||||
expires_at,
|
||||
};
|
||||
|
||||
state.ws_tickets.insert(ticket_id.clone(), ticket);
|
||||
|
||||
tracing::info!(
|
||||
user_id = %auth.user_id,
|
||||
username = %auth.username,
|
||||
ticket = %ticket_id,
|
||||
"WS ticket issued"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "ticket": ticket_id })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/ws/jobs ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsQuery {
|
||||
pub ticket: String,
|
||||
}
|
||||
|
||||
/// Browser WebSocket upgrade endpoint — authenticates via single-use ticket.
|
||||
pub async fn ws_handler(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<WsQuery>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, (StatusCode, Json<Value>)> {
|
||||
// Validate and consume the ticket atomically.
|
||||
let ticket = {
|
||||
let entry = state.ws_tickets.get(&q.ticket);
|
||||
match entry {
|
||||
None => {
|
||||
return Err(err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_ticket",
|
||||
"WebSocket ticket not found or already used",
|
||||
));
|
||||
},
|
||||
Some(t) => {
|
||||
if t.expires_at < Utc::now() {
|
||||
drop(t);
|
||||
state.ws_tickets.remove(&q.ticket);
|
||||
return Err(err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"ticket_expired",
|
||||
"WebSocket ticket has expired",
|
||||
));
|
||||
}
|
||||
t.clone()
|
||||
},
|
||||
}
|
||||
};
|
||||
// Single-use: remove immediately after validation.
|
||||
state.ws_tickets.remove(&q.ticket);
|
||||
|
||||
tracing::info!(
|
||||
user_id = %ticket.user_id,
|
||||
role = %ticket.role,
|
||||
"Browser WebSocket connection upgraded"
|
||||
);
|
||||
|
||||
let db = state.db.clone();
|
||||
Ok(ws.on_upgrade(move |socket| handle_browser_ws(socket, db, ticket)))
|
||||
}
|
||||
|
||||
// ── WebSocket handler ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Drive the browser WebSocket: LISTEN on `job_update` and forward payloads.
|
||||
async fn handle_browser_ws(mut socket: WebSocket, db: sqlx::PgPool, ticket: WsTicket) {
|
||||
// Acquire a dedicated PG listener connection.
|
||||
let mut listener = match PgListener::connect_with(&db).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "Failed to create PgListener");
|
||||
let _ = socket
|
||||
.send(Message::Text(
|
||||
json!({ "error": "internal_error" }).to_string().into(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = listener.listen("job_update").await {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener LISTEN failed");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS: LISTEN job_update started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward PG notifications to the browser.
|
||||
notify_result = listener.recv() => {
|
||||
match notify_result {
|
||||
Ok(notification) => {
|
||||
let payload = notification.payload().to_string();
|
||||
tracing::debug!(user_id = %ticket.user_id, payload = %payload, "Forwarding job_update");
|
||||
if socket.send(Message::Text(payload.into())).await.is_err() {
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS send failed — client disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener recv error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle incoming frames from the browser (ping/close).
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS closed by client");
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) if socket.send(Message::Pong(data.clone())).await.is_err() => {
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
tracing::debug!(error = %e, user_id = %ticket.user_id, "Browser WS recv error");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS handler exiting");
|
||||
}
|
||||
Reference in New Issue
Block a user