Private
Public Access
1
0
Files
linux_patch_manager/crates/pm-web/src/main.rs
Echo d326b25203
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 53s
CI Pipeline / Rust Unit Tests (push) Successful in 1m11s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 14s
CI Pipeline / Build .deb & Release (push) Has been skipped
fix(ca): make CA path configurable and prevent encrypted keys
- main.rs: use config.security.ca_cert_path parent directory instead
  of hardcoded /etc/patch-manager/ca for CA initialization.
- config.example.toml: add warning that CA key must be unencrypted PEM.
- This prevents silent generation of a second CA on fresh installs
  and ensures the manager always uses the configured CA.
2026-05-18 15:58:38 +00:00

311 lines
12 KiB
Rust

//! pm-web — Linux Patch Manager web server.
mod routes;
use axum::{extract::State, http::StatusCode, middleware, response::Json, routing::get, Router};
use axum_server::tls_rustls::RustlsConfig;
use dashmap::DashMap;
use pm_auth::{
jwt,
rbac::{require_auth, AuthConfig},
};
use pm_core::{
config::AppConfig, db, logging, models::PkiBundle, request_id::request_id_middleware,
};
use routes::sso::{OidcCache, SsoSession};
use routes::ws::WsTicket;
use serde_json::{json, Value};
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
/// Shared application state threaded through Axum.
///
/// Cloned into every handler via `State<AppState>`; all shared/mutable members
/// are behind `Arc`, so a clone is a handful of reference-count bumps.
#[derive(Clone)]
pub struct AppState {
    /// PostgreSQL connection pool shared by all handlers (migrations are run
    /// before the state is built).
    pub db: sqlx::PgPool,
    /// Application configuration loaded at startup (or defaults if loading failed).
    pub config: Arc<AppConfig>,
    /// PEM-encoded JWT signing key. Empty string when the key file could not be
    /// loaded at startup ("dev mode").
    pub signing_key_pem: String,
    /// JWT verification key and IP whitelist consumed by the `require_auth`
    /// middleware on protected routes.
    pub auth_config: Arc<AuthConfig>,
    /// In-memory store for single-use WebSocket authentication tickets.
    /// Expired tickets are purged periodically by a background task.
    pub ws_tickets: Arc<DashMap<String, WsTicket>>,
    /// In-memory store for SSO PKCE sessions (state → code_verifier).
    /// Sessions older than 10 minutes are purged by a background task.
    pub sso_sessions: Arc<DashMap<String, SsoSession>>,
    /// Cached OIDC discovery document and JWKS for SSO id_token verification.
    pub oidc_cache: Arc<Mutex<OidcCache>>,
    /// Internal certificate authority for mTLS client cert issuance.
    pub ca: Arc<pm_ca::CertAuthority>,
    /// IP-based rate limits for enrollment requests. Entries older than one
    /// hour are purged; presumably the `Instant` is the time of the client's
    /// last enrollment request — TODO confirm against the enrollment routes.
    pub enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>>,
    /// Short-lived cache for approved enrollment PKI bundles (cleared in full
    /// on a fixed interval by a background task).
    pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Install the default crypto provider for rustls (required since 0.23)
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install rustls crypto provider");
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
eprintln!("Config file not found or invalid, using defaults");
AppConfig::default()
});
logging::init(&config.logging);
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
"patch-manager-web starting"
);
let signing_key_pem = jwt::load_signing_key(&config.security.jwt_signing_key_path)
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "JWT signing key not found (dev mode)");
String::new()
});
let verify_key_pem =
jwt::load_verify_key(&config.security.jwt_verify_key_path).unwrap_or_else(|e| {
tracing::warn!(error = %e, "JWT verify key not found (dev mode)");
String::new()
});
let auth_config = Arc::new(AuthConfig::new(
verify_key_pem,
&config.security.ip_whitelist,
));
let pool = db::init_pool(&config.database).await?;
db::run_migrations(&pool).await?;
// Initialise the internal CA using the configured certificate paths.
// The CA certificate and key must exist at the configured locations and be
// unencrypted PEM. If absent, a new CA is generated in that directory.
let ca_base = std::path::Path::new(&config.security.ca_cert_path)
.parent()
.expect("CA certificate path must have a parent directory");
let ca = pm_ca::CertAuthority::init(ca_base, &pool)
.await
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "CA init failed (dev mode)");
panic!("CA initialization failed: {}", e);
});
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
let enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>> = Arc::new(DashMap::new());
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
// Background task: purge expired WS tickets every 30 seconds.
{
let tickets = ws_tickets.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
let now = chrono::Utc::now();
let before = tickets.len();
tickets.retain(|_, v| v.expires_at > now);
let removed = before.saturating_sub(tickets.len());
if removed > 0 {
tracing::debug!(removed, "Purged expired WS tickets");
}
}
});
}
// Background task: purge expired SSO sessions every 60 seconds (sessions older than 10 minutes).
{
let sessions = sso_sessions.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
let now = chrono::Utc::now();
let cutoff = now - chrono::Duration::minutes(10);
let before = sessions.len();
sessions.retain(|_, v| v.created_at > cutoff);
let removed = before.saturating_sub(sessions.len());
if removed > 0 {
tracing::debug!(removed, "Purged expired SSO sessions");
}
}
});
}
// Background task: purge expired enrollment rate limits every 5 minutes.
{
let limits = enrollment_rate_limits.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300));
loop {
interval.tick().await;
let now = Instant::now();
limits.retain(|_, v| now.duration_since(*v) < Duration::from_secs(3600));
}
});
}
// Background task: purge approved enrollment PKI bundles every 10 minutes.
{
let approved = approved_enrollments.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(600));
loop {
interval.tick().await;
approved.clear();
}
});
}
let state = AppState {
db: pool,
config: Arc::new(config.clone()),
signing_key_pem,
auth_config,
ws_tickets,
sso_sessions,
ca: Arc::new(ca),
enrollment_rate_limits,
approved_enrollments,
oidc_cache,
};
let app = build_router(state);
let addr: SocketAddr = format!("{}:{}", config.server.host, config.server.port)
.parse()
.expect("Invalid bind address");
// Try to load TLS certificate and key; fall back to plain HTTP if missing.
let tls_cert = std::path::Path::new(&config.security.web_tls_cert_path);
let tls_key = std::path::Path::new(&config.security.web_tls_key_path);
if tls_cert.exists() && tls_key.exists() {
let tls_config = RustlsConfig::from_pem_file(
&config.security.web_tls_cert_path,
&config.security.web_tls_key_path,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load TLS certificates");
e
})?;
tracing::info!(%addr, "Listening (HTTPS)");
axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service())
.await?;
} else {
tracing::warn!(
cert_path = %config.security.web_tls_cert_path,
key_path = %config.security.web_tls_key_path,
"TLS certificates not found — falling back to plain HTTP. \
This is insecure for production!"
);
tracing::info!(%addr, "Listening (HTTP — no TLS)");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
}
Ok(())
}
/// Construct the full Axum router.
///
/// Layout:
/// - `/status/health`: unauthenticated health check.
/// - Public `/api/v1` routes (auth, enrollment, SSO, Azure SSO compatibility): no JWT.
/// - Protected `/api/v1` routes: wrapped in the `require_auth` middleware via
///   `route_layer`, which axum applies only to requests matching those routes.
/// - WebSocket upgrade route: authenticated by single-use ticket, outside the
///   JWT middleware.
/// - Anything else falls back to the React SPA in `static_dir`, serving
///   `index.html` for paths with no matching file so client-side routing works.
///
/// The request-ID and HTTP tracing layers wrap the whole router.
pub fn build_router(state: AppState) -> Router {
    let static_dir = state.config.server.static_dir.clone();
    // Join as a path instead of string formatting: with an empty `static_dir`
    // the old `format!("{}/index.html")` pointed at `/index.html` in the
    // filesystem root instead of next to `ServeDir`'s (cwd-relative) root.
    let spa_index = std::path::Path::new(&static_dir).join("index.html");
    let auth_config = state.auth_config.clone();

    // All protected API routes — require a valid JWT.
    let protected_api = Router::new()
        // Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
        .nest("/auth", routes::auth::protected_router())
        // Hosts
        .nest("/hosts", routes::hosts::router())
        // Host-scoped certificate endpoints, registered under the same /hosts
        // prefix as the hosts router; the two routers' paths must not overlap.
        .nest("/hosts", routes::ca::host_cert_router())
        // Groups
        .nest("/groups", routes::groups::router())
        // Users
        .nest("/users", routes::users::router())
        // Discovery
        .nest("/discovery", routes::discovery::router())
        // Fleet status
        .nest("/status", routes::status::router())
        // Patch jobs
        .nest("/jobs", routes::jobs::router())
        // Maintenance windows (nested under hosts path param)
        .nest(
            "/hosts/{host_id}/maintenance-windows",
            routes::maintenance_windows::router(),
        )
        // CA root certificate download
        .nest("/ca", routes::ca::ca_router())
        // Certificate list / renew / revoke
        .nest("/certificates", routes::ca::certs_router())
        // WS ticket issuance (JWT-protected — ticket returned to browser, then used for WS upgrade)
        .merge(routes::ws::ticket_router())
        // Reports
        .nest("/reports", routes::reports::router())
        // Health checks (nested under hosts path param)
        .nest(
            "/hosts/{host_id}/health-checks",
            routes::health_checks::router(),
        )
        // Settings (admin-only)
        .nest("/settings", routes::settings::router())
        // Admin enrollment routes (JWT protected, Admin role enforced)
        .nest("/admin", routes::enrollment::admin_router())
        // Apply auth middleware to all the above
        .route_layer(middleware::from_fn(move |req, next| {
            let auth_config = auth_config.clone();
            require_auth(auth_config, req, next)
        }));

    Router::new()
        .route("/status/health", get(health_handler))
        // Public auth routes (no JWT needed)
        .nest("/api/v1/auth", routes::auth::public_router())
        // Public enrollment endpoints (rate-limited, no JWT)
        .nest("/api/v1", routes::enrollment::router())
        // Public SSO routes (no JWT needed)
        .nest("/api/v1/auth/sso", routes::sso::public_router())
        // Public Azure SSO routes (no JWT needed)
        .nest("/api/v1/auth/azure", routes::sso::azure_compat_router())
        // Protected API routes (JWT required)
        .nest("/api/v1", protected_api)
        // WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
        .merge(routes::ws::ws_router())
        // Serve React SPA
        .fallback_service(
            ServeDir::new(&static_dir)
                .append_index_html_on_directories(true)
                .fallback(ServeFile::new(spa_index)),
        )
        .layer(middleware::from_fn(request_id_middleware))
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
async fn health_handler(State(state): State<AppState>) -> Result<Json<Value>, StatusCode> {
let db_ok = sqlx::query("SELECT 1").execute(&state.db).await.is_ok();
let status = if db_ok { "healthy" } else { "degraded" };
let body = json!({ "service": "patch-manager-web", "version": env!("CARGO_PKG_VERSION"), "status": status, "database": if db_ok { "ok" } else { "error" } });
if db_ok {
Ok(Json(body))
} else {
Err(StatusCode::SERVICE_UNAVAILABLE)
}
}