From 5a4d4d583ecaa75ba0c04383355430ca0e750cd9 Mon Sep 17 00:00:00 2001 From: Echo Date: Fri, 24 Apr 2026 15:32:50 +0000 Subject: [PATCH] style: Apply rustfmt with stable-only config - Fixed rustfmt.toml to only use stable options (removed nightly-only) - Applied cargo fmt --all to fix formatting violations - Stable options: edition=2021, max_width=100, reorder_imports/modules, match_block_trailing_comma --- crates/pm-agent-client/src/client.rs | 40 +-- crates/pm-agent-client/src/lib.rs | 5 +- crates/pm-auth/src/jwt.rs | 11 +- crates/pm-auth/src/mfa_totp.rs | 6 +- crates/pm-auth/src/password.rs | 17 +- crates/pm-auth/src/rbac.rs | 8 +- crates/pm-auth/src/refresh.rs | 20 +- crates/pm-auth/src/session.rs | 18 +- crates/pm-ca/src/ca.rs | 72 ++-- crates/pm-core/src/audit.rs | 27 +- crates/pm-core/src/config.rs | 2 +- crates/pm-core/src/db.rs | 10 +- crates/pm-core/src/error.rs | 14 +- crates/pm-core/src/lib.rs | 13 +- crates/pm-core/src/logging.rs | 9 +- crates/pm-core/src/models.rs | 14 +- crates/pm-core/src/request_id.rs | 7 +- crates/pm-reports/src/csv.rs | 163 +++++---- crates/pm-reports/src/pdf.rs | 322 ++++++++++++------ crates/pm-web/src/main.rs | 55 ++- crates/pm-web/src/routes/auth.rs | 40 ++- crates/pm-web/src/routes/azure_sso.rs | 69 ++-- crates/pm-web/src/routes/ca.rs | 54 ++- crates/pm-web/src/routes/discovery.rs | 106 ++++-- crates/pm-web/src/routes/groups.rs | 216 +++++++++--- crates/pm-web/src/routes/hosts.rs | 119 ++++--- crates/pm-web/src/routes/jobs.rs | 121 ++++--- .../pm-web/src/routes/maintenance_windows.rs | 126 ++++--- crates/pm-web/src/routes/mod.rs | 6 +- crates/pm-web/src/routes/reports.rs | 52 +-- crates/pm-web/src/routes/settings.rs | 85 +++-- crates/pm-web/src/routes/status.rs | 8 +- crates/pm-web/src/routes/users.rs | 179 ++++++++-- crates/pm-web/src/routes/ws.rs | 15 +- crates/pm-worker/src/agent_loader.rs | 12 +- crates/pm-worker/src/email.rs | 43 ++- crates/pm-worker/src/health_poller.rs | 51 +-- crates/pm-worker/src/job_executor.rs | 172 +++++----- crates/pm-worker/src/main.rs | 30 +- crates/pm-worker/src/maintenance_scheduler.rs | 2 +- crates/pm-worker/src/patch_poller.rs | 39 +-- crates/pm-worker/src/refresh_listener.rs | 62 ++-- crates/pm-worker/src/ws_relay.rs | 86 +++-- rustfmt.toml | 12 +- 44 files changed, 1498 insertions(+), 1040 deletions(-) diff --git a/crates/pm-agent-client/src/client.rs b/crates/pm-agent-client/src/client.rs index e023753..5b5e7df 100644 --- a/crates/pm-agent-client/src/client.rs +++ b/crates/pm-agent-client/src/client.rs @@ -22,18 +22,15 @@ use std::time::Duration; -use reqwest::{ - tls::Version, - Certificate, ClientBuilder, Identity, -}; +use reqwest::{tls::Version, Certificate, ClientBuilder, Identity}; use serde::{de::DeserializeOwned, Serialize}; use tracing::{debug, instrument}; use crate::{ error::AgentClientError, types::{ - AgentEnvelope, HealthData, PackagesData, PatchesData, SystemInfoData, - ApplyPatchesRequest, ApplyPatchesResponse, AgentJobStatus, RollbackResponse, + AgentEnvelope, AgentJobStatus, ApplyPatchesRequest, ApplyPatchesResponse, HealthData, + PackagesData, PatchesData, RollbackResponse, SystemInfoData, }, }; @@ -151,11 +148,7 @@ impl AgentClient { /// Execute a GET request against `{base_url}/{path}` with optional query /// parameters, deserialize the [`AgentEnvelope`], and extract the `data` /// field — or propagate an [`AgentClientError::ApiError`]. - async fn get( - &self, - path: &str, - query: &[(&str, &str)], - ) -> Result + async fn get(&self, path: &str, query: &[(&str, &str)]) -> Result where T: DeserializeOwned, { @@ -190,11 +183,7 @@ impl AgentClient { // Fallback: use the HTTP status as the error indicator. return Err(AgentClientError::ApiError { code: status.as_str().to_string(), - message: format!( - "Agent returned HTTP {} for {}", - status.as_u16(), - url - ), + message: format!("Agent returned HTTP {} for {}", status.as_u16(), url), }); } @@ -220,21 +209,16 @@ impl AgentClient { /// `GET /api/v1/jobs/{id}` — poll an async agent job for status. #[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))] - pub async fn job_status( - &self, - job_id: &str, - ) -> Result { + pub async fn job_status(&self, job_id: &str) -> Result { self.get(&format!("jobs/{}", job_id), &[]).await } /// `POST /api/v1/jobs/{id}/rollback` — trigger rollback on the agent. #[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))] - pub async fn rollback_job( - &self, - job_id: &str, - ) -> Result { + pub async fn rollback_job(&self, job_id: &str) -> Result { let empty: serde_json::Value = serde_json::json!({}); - self.post(&format!("jobs/{}/rollback", job_id), &empty).await + self.post(&format!("jobs/{}/rollback", job_id), &empty) + .await } // -------------------------------------------------------- @@ -244,11 +228,7 @@ impl AgentClient { /// Execute a POST request against `{base_url}/{path}`, serialize `body` as /// JSON, deserialize the [`AgentEnvelope`], and extract the `data` field — /// or propagate an [`AgentClientError::ApiError`]. - async fn post( - &self, - path: &str, - body: &Req, - ) -> Result + async fn post(&self, path: &str, body: &Req) -> Result where Req: Serialize, Resp: DeserializeOwned, diff --git a/crates/pm-agent-client/src/lib.rs b/crates/pm-agent-client/src/lib.rs index 572ad39..3fd1fad 100644 --- a/crates/pm-agent-client/src/lib.rs +++ b/crates/pm-agent-client/src/lib.rs @@ -38,9 +38,6 @@ pub use error::AgentClientError; /// Response envelope and all data types. pub use types::{ - AgentEnvelope, AgentErrorBody, - HealthData, - Package, PackagesData, - Patch, PatchesData, + AgentEnvelope, AgentErrorBody, HealthData, Package, PackagesData, Patch, PatchesData, SystemInfoData, }; diff --git a/crates/pm-auth/src/jwt.rs b/crates/pm-auth/src/jwt.rs index d1f8dff..f4ef247 100644 --- a/crates/pm-auth/src/jwt.rs +++ b/crates/pm-auth/src/jwt.rs @@ -91,10 +91,7 @@ pub fn issue_access_token( } /// Validate and decode an access token using the Ed25519 public key PEM. -pub fn validate_access_token( - token: &str, - verify_key_pem: &str, -) -> Result { +pub fn validate_access_token(token: &str, verify_key_pem: &str) -> Result { let key = DecodingKey::from_ed_pem(verify_key_pem.as_bytes()) .map_err(|e| JwtError::KeyLoad(e.to_string()))?; @@ -115,14 +112,12 @@ pub fn validate_access_token( /// Load the Ed25519 signing key from a PEM file path. pub fn load_signing_key(path: &str) -> Result { - std::fs::read_to_string(path) - .map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}"))) + std::fs::read_to_string(path).map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}"))) } /// Load the Ed25519 verification (public) key from a PEM file path. pub fn load_verify_key(path: &str) -> Result { - std::fs::read_to_string(path) - .map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}"))) + std::fs::read_to_string(path).map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}"))) } #[cfg(test)] diff --git a/crates/pm-auth/src/mfa_totp.rs b/crates/pm-auth/src/mfa_totp.rs index 2c68579..1bd1e1b 100644 --- a/crates/pm-auth/src/mfa_totp.rs +++ b/crates/pm-auth/src/mfa_totp.rs @@ -66,9 +66,9 @@ fn build_totp(username: &str, secret_base32: &str) -> Result { // new(issuer, account_name, algorithm, digits, skew, step, secret) TOTP::new( Algorithm::SHA1, - 6, // digits - 1, // skew - 30, // step (seconds) + 6, // digits + 1, // skew + 30, // step (seconds) secret_bytes, Some(ISSUER.to_string()), username.to_string(), diff --git a/crates/pm-auth/src/password.rs b/crates/pm-auth/src/password.rs index fb3d2b5..c18491d 100644 --- a/crates/pm-auth/src/password.rs +++ b/crates/pm-auth/src/password.rs @@ -7,17 +7,15 @@ //! - Parallelism: 1 use argon2::{ - password_hash::{ - rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString, - }, + password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2, Params, Version, }; use thiserror::Error; /// Argon2id parameters per spec. const M_COST: u32 = 65536; // 64 MiB -const T_COST: u32 = 3; // 3 iterations -const P_COST: u32 = 1; // 1 thread +const T_COST: u32 = 3; // 3 iterations +const P_COST: u32 = 1; // 1 thread #[derive(Debug, Error)] pub enum PasswordError { @@ -33,7 +31,11 @@ pub enum PasswordError { fn argon2() -> Result, PasswordError> { let params = Params::new(M_COST, T_COST, P_COST, None) .map_err(|e| PasswordError::HashError(e.to_string()))?; - Ok(Argon2::new(argon2::Algorithm::Argon2id, Version::V0x13, params)) + Ok(Argon2::new( + argon2::Algorithm::Argon2id, + Version::V0x13, + params, + )) } /// Hash a plaintext password using Argon2id with a random salt. @@ -54,8 +56,7 @@ pub fn hash_password(password: &str) -> Result { /// /// Returns `Ok(true)` if the password matches, `Ok(false)` if not. pub fn verify_password(password: &str, hash: &str) -> Result { - let parsed_hash = - PasswordHash::new(hash).map_err(|_| PasswordError::InvalidHash)?; + let parsed_hash = PasswordHash::new(hash).map_err(|_| PasswordError::InvalidHash)?; let argon2 = argon2()?; diff --git a/crates/pm-auth/src/rbac.rs b/crates/pm-auth/src/rbac.rs index 4cd77b6..1ae1ad6 100644 --- a/crates/pm-auth/src/rbac.rs +++ b/crates/pm-auth/src/rbac.rs @@ -143,11 +143,7 @@ fn forbidden(message: &str) -> Response { /// /// Inserts `AuthUser` into request extensions on success. /// Rejects with 401 if token is missing/invalid, 403 if IP is blocked. -pub async fn require_auth( - auth_config: Arc, - mut req: Request, - next: Next, -) -> Response { +pub async fn require_auth(auth_config: Arc, mut req: Request, next: Next) -> Response { // IP whitelist check if let Some(ip) = extract_remote_ip(req.headers()) { if !auth_config.is_ip_allowed(&ip) { @@ -168,7 +164,7 @@ pub async fn require_auth( Err(e) => { tracing::debug!(error = %e, "JWT validation failed"); return unauthorized("Invalid token"); - } + }, }; let role = match UserRole::from_str(&claims.role) { diff --git a/crates/pm-auth/src/refresh.rs b/crates/pm-auth/src/refresh.rs index 4a97cb0..287b1c0 100644 --- a/crates/pm-auth/src/refresh.rs +++ b/crates/pm-auth/src/refresh.rs @@ -123,12 +123,10 @@ pub async fn rotate( } // Revoke old token - sqlx::query( - "UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE id = $1", - ) - .bind(stored.id) - .execute(pool) - .await?; + sqlx::query("UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE id = $1") + .bind(stored.id) + .execute(pool) + .await?; // Issue new token let new_token = issue(pool, stored.user_id, user_agent, ip_address).await?; @@ -138,10 +136,7 @@ pub async fn rotate( } /// Revoke all refresh tokens for a user (force logout). -pub async fn revoke_all_for_user( - pool: &PgPool, - user_id: Uuid, -) -> Result { +pub async fn revoke_all_for_user(pool: &PgPool, user_id: Uuid) -> Result { let result = sqlx::query( "UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE user_id = $1 AND revoked = FALSE", ) @@ -154,10 +149,7 @@ pub async fn revoke_all_for_user( } /// Revoke a single refresh token by its raw value. -pub async fn revoke( - pool: &PgPool, - raw_token: &str, -) -> Result<(), RefreshError> { +pub async fn revoke(pool: &PgPool, raw_token: &str) -> Result<(), RefreshError> { let hash = hex::encode(Sha256::digest(raw_token.as_bytes())); sqlx::query( diff --git a/crates/pm-auth/src/session.rs b/crates/pm-auth/src/session.rs index de0e710..648ddf5 100644 --- a/crates/pm-auth/src/session.rs +++ b/crates/pm-auth/src/session.rs @@ -122,13 +122,12 @@ pub async fn login( // Prevent timing-based username enumeration let _ = password::hash_password("dummy-timing-fill"); return Err(SessionError::InvalidCredentials); - } + }, }; // 2. Verify password let hash = user.password_hash.as_deref().unwrap_or(""); - let valid = password::verify_password(&req.password, hash) - .unwrap_or(false); + let valid = password::verify_password(&req.password, hash).unwrap_or(false); if !valid { tracing::warn!(username = %req.username, "Login failed: invalid password"); @@ -146,8 +145,7 @@ pub async fn login( let code = req.totp_code.as_deref().ok_or(SessionError::MfaRequired)?; let secret = user.totp_secret.as_deref().unwrap_or(""); - let mfa_ok = mfa_totp::verify_code(&user.username, secret, code) - .unwrap_or(false); + let mfa_ok = mfa_totp::verify_code(&user.username, secret, code).unwrap_or(false); if !mfa_ok { tracing::warn!(username = %req.username, "Login failed: invalid MFA code"); @@ -246,19 +244,13 @@ pub async fn refresh_session( } /// Logout: revoke the current refresh token. -pub async fn logout( - pool: &PgPool, - raw_refresh_token: &str, -) -> Result<(), SessionError> { +pub async fn logout(pool: &PgPool, raw_refresh_token: &str) -> Result<(), SessionError> { refresh::revoke(pool, raw_refresh_token).await?; Ok(()) } /// Force-logout: revoke all refresh tokens for a user. -pub async fn force_logout( - pool: &PgPool, - user_id: Uuid, -) -> Result { +pub async fn force_logout(pool: &PgPool, user_id: Uuid) -> Result { let count = refresh::revoke_all_for_user(pool, user_id).await?; Ok(count) } diff --git a/crates/pm-ca/src/ca.rs b/crates/pm-ca/src/ca.rs index b46e1f5..fe88f2c 100644 --- a/crates/pm-ca/src/ca.rs +++ b/crates/pm-ca/src/ca.rs @@ -13,8 +13,8 @@ use chrono::{DateTime, Duration as ChronoDuration, Utc}; use rand::RngCore; use rcgen::{ BasicConstraints, Certificate, CertificateParams, DistinguishedName, DnType, - ExtendedKeyUsagePurpose, Ia5String, IsCa, KeyPair, KeyUsagePurpose, SanType, - SerialNumber, PKCS_ECDSA_P256_SHA256, + ExtendedKeyUsagePurpose, Ia5String, IsCa, KeyPair, KeyUsagePurpose, SanType, SerialNumber, + PKCS_ECDSA_P256_SHA256, }; use sqlx::{PgPool, Row}; use time::{Duration as TimeDuration, OffsetDateTime}; @@ -83,10 +83,7 @@ fn chrono_offset_days(days: i64) -> DateTime { /// Build a `CertificateParams` with common fields pre-filled. /// Caller still needs to set `is_ca`, `key_usages`, `extended_key_usages`, and `subject_alt_names`. -fn base_params( - cn: &str, - validity_days: i64, -) -> Result<(CertificateParams, String, DateTime)> { +fn base_params(cn: &str, validity_days: i64) -> Result<(CertificateParams, String, DateTime)> { let (serial, serial_hex) = make_serial(); let expires_at = chrono_offset_days(validity_days); @@ -144,8 +141,7 @@ impl CertAuthority { .context("read ca.crt")?; // Validate that both PEMs parse without error. - KeyPair::from_pem(&ca_key_pem) - .context("parse CA private-key PEM")?; + KeyPair::from_pem(&ca_key_pem).context("parse CA private-key PEM")?; CertificateParams::from_ca_cert_pem(&ca_cert_pem) .context("parse CA certificate PEM")?; @@ -166,8 +162,8 @@ impl CertAuthority { .await .context("create CA directory")?; - let ca_key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256) - .context("generate CA key pair")?; + let ca_key = + KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate CA key pair")?; let (serial, serial_hex) = make_serial(); let expires_at = chrono_offset_days(365 * 10); @@ -177,20 +173,18 @@ impl CertAuthority { params.not_after = odt_offset_days(365 * 10); params.serial_number = Some(serial); params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - params.key_usages = vec![ - KeyUsagePurpose::KeyCertSign, - KeyUsagePurpose::CrlSign, - ]; + params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign]; let mut dn = DistinguishedName::new(); dn.push(DnType::CommonName, "Patch Manager Root CA"); dn.push(DnType::OrganizationName, "Patch Manager"); params.distinguished_name = dn; - let ca_cert_obj = params.self_signed(&ca_key) + let ca_cert_obj = params + .self_signed(&ca_key) .context("self-sign CA certificate")?; let ca_cert_pem = ca_cert_obj.pem(); - let ca_key_pem = ca_key.serialize_pem(); + let ca_key_pem = ca_key.serialize_pem(); write_protected(&key_path, &ca_key_pem) .await @@ -256,8 +250,8 @@ impl CertAuthority { ) -> Result { tracing::info!(host_id = %host_id, hostname, "Issuing mTLS client certificate"); - let key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256) - .context("generate client key pair")?; + let key = + KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate client key pair")?; let (mut params, serial_hex, expires_at) = base_params(hostname, 365)?; params.is_ca = IsCa::ExplicitNoCa; @@ -270,7 +264,7 @@ impl CertAuthority { .context("sign client cert with CA")?; let cert_pem = cert.pem(); - let key_pem = key.serialize_pem(); + let key_pem = key.serialize_pem(); sqlx::query( "INSERT INTO certificates \ @@ -294,7 +288,12 @@ impl CertAuthority { "Client certificate issued successfully" ); - Ok(IssuedCert { cert_pem, key_pem, serial_number: serial_hex, expires_at }) + Ok(IssuedCert { + cert_pem, + key_pem, + serial_number: serial_hex, + expires_at, + }) } /// Revoke a certificate by database ID. @@ -328,18 +327,16 @@ impl CertAuthority { tracing::info!(cert_id = %cert_id, "Renewing certificate"); // Fetch the existing cert's host_id and common_name. - let row = sqlx::query( - "SELECT host_id, common_name FROM certificates WHERE id = $1", - ) - .bind(cert_id) - .fetch_one(db) - .await - .context("fetch certificate for renewal")?; + let row = sqlx::query("SELECT host_id, common_name FROM certificates WHERE id = $1") + .bind(cert_id) + .fetch_one(db) + .await + .context("fetch certificate for renewal")?; - let host_id: Uuid = row.try_get("host_id") + let host_id: Uuid = row + .try_get("host_id") .context("certificate has no host_id (cannot renew root CA)")?; - let common_name: String = row.try_get("common_name") - .context("fetch common_name")?; + let common_name: String = row.try_get("common_name").context("fetch common_name")?; // Revoke the old cert first. self.revoke_cert(cert_id, db).await?; @@ -364,14 +361,11 @@ impl CertAuthority { /// /// Returns `(cert_pem, key_pem)`. This certificate is **not** stored in the /// database; it is intended for runtime use only. - pub async fn issue_web_tls_cert( - &self, - hostname: &str, - ) -> Result<(String, String)> { + pub async fn issue_web_tls_cert(&self, hostname: &str) -> Result<(String, String)> { tracing::info!(hostname, "Issuing web TLS certificate"); - let key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256) - .context("generate web TLS key pair")?; + let key = + KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate web TLS key pair")?; let (mut params, serial_hex, expires_at) = base_params(hostname, 365)?; params.is_ca = IsCa::ExplicitNoCa; @@ -387,7 +381,7 @@ impl CertAuthority { .context("sign web TLS cert with CA")?; let cert_pem = cert.pem(); - let key_pem = key.serialize_pem(); + let key_pem = key.serialize_pem(); tracing::info!( hostname, @@ -408,8 +402,8 @@ impl CertAuthority { /// The returned `Certificate` is used solely as an issuer reference when /// signing leaf certificates; it is never distributed directly. fn ca_objects(&self) -> Result<(KeyPair, Certificate)> { - let key = KeyPair::from_pem(&self.ca_key_pem) - .context("reconstruct CA key pair from PEM")?; + let key = + KeyPair::from_pem(&self.ca_key_pem).context("reconstruct CA key pair from PEM")?; let params = CertificateParams::from_ca_cert_pem(&self.ca_cert_pem) .context("reconstruct CA params from PEM")?; let cert = params diff --git a/crates/pm-core/src/audit.rs b/crates/pm-core/src/audit.rs index 757031c..09f19b4 100644 --- a/crates/pm-core/src/audit.rs +++ b/crates/pm-core/src/audit.rs @@ -101,8 +101,15 @@ pub async fn log_event( request_id: Option<&str>, ) { let result = write_audit_row( - pool, action, actor_user_id, actor_username, - target_type, target_id, details, ip_address, request_id, + pool, + action, + actor_user_id, + actor_username, + target_type, + target_id, + details, + ip_address, + request_id, ) .await; @@ -123,11 +130,10 @@ async fn write_audit_row( request_id: Option<&str>, ) -> Result<(), sqlx::Error> { // Fetch previous hash for chain - let prev_hash: Option = sqlx::query_scalar( - "SELECT row_hash FROM audit_log ORDER BY id DESC LIMIT 1", - ) - .fetch_optional(pool) - .await?; + let prev_hash: Option = + sqlx::query_scalar("SELECT row_hash FROM audit_log ORDER BY id DESC LIMIT 1") + .fetch_optional(pool) + .await?; let prev = prev_hash.unwrap_or_default(); let now = chrono::Utc::now().to_rfc3339(); @@ -245,7 +251,7 @@ pub async fn verify_integrity(pool: &PgPool) -> IntegrityResult { rows_checked: 0, errors: vec![], }; - } + }, }; let mut errors = Vec::new(); @@ -273,10 +279,7 @@ pub async fn verify_integrity(pool: &PgPool) -> IntegrityResult { .unwrap_or_default(); let ip_str = row.ip_address.as_deref().unwrap_or(""); let rid = row.request_id.as_deref().unwrap_or(""); - let created_str = row - .created_at - .map(|c| c.to_rfc3339()) - .unwrap_or_default(); + let created_str = row.created_at.map(|c| c.to_rfc3339()).unwrap_or_default(); let mut hasher = Sha256::new(); hasher.update(row.prev_hash.as_bytes()); diff --git a/crates/pm-core/src/config.rs b/crates/pm-core/src/config.rs index 357d567..a9e94e4 100644 --- a/crates/pm-core/src/config.rs +++ b/crates/pm-core/src/config.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use config::{Config, ConfigError, Environment, File}; +use serde::{Deserialize, Serialize}; /// Top-level application configuration. #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/crates/pm-core/src/db.rs b/crates/pm-core/src/db.rs index 453d403..a9311bc 100644 --- a/crates/pm-core/src/db.rs +++ b/crates/pm-core/src/db.rs @@ -1,6 +1,6 @@ +use crate::config::DatabaseConfig; use sqlx::postgres::{PgPool, PgPoolOptions}; use std::time::Duration; -use crate::config::DatabaseConfig; /// Initialize and return a PostgreSQL connection pool. pub async fn init_pool(cfg: &DatabaseConfig) -> Result { @@ -59,11 +59,9 @@ pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::migrate::MigrateE /// Check that the database schema is at the expected version. /// Used by the worker to wait until migrations have been applied. pub async fn check_schema_version(pool: &PgPool) -> Result { - let row: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM _sqlx_migrations WHERE success = true", - ) - .fetch_one(pool) - .await?; + let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_migrations WHERE success = true") + .fetch_one(pool) + .await?; Ok(row.0) } diff --git a/crates/pm-core/src/error.rs b/crates/pm-core/src/error.rs index e9cd4b0..2c4d046 100644 --- a/crates/pm-core/src/error.rs +++ b/crates/pm-core/src/error.rs @@ -86,9 +86,11 @@ impl IntoResponse for AppError { AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, "forbidden", msg.clone()), AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, "bad_request", msg.clone()), AppError::Conflict(msg) => (StatusCode::CONFLICT, "conflict", msg.clone()), - AppError::UnprocessableEntity(msg) => { - (StatusCode::UNPROCESSABLE_ENTITY, "unprocessable_entity", msg.clone()) - } + AppError::UnprocessableEntity(msg) => ( + StatusCode::UNPROCESSABLE_ENTITY, + "unprocessable_entity", + msg.clone(), + ), AppError::Database(e) => { tracing::error!(error = %e, "Database error"); ( @@ -96,7 +98,7 @@ impl IntoResponse for AppError { "internal_error", "An internal error occurred".to_string(), ) - } + }, AppError::Internal(e) => { tracing::error!(error = %e, "Internal error"); ( @@ -104,7 +106,7 @@ impl IntoResponse for AppError { "internal_error", "An internal error occurred".to_string(), ) - } + }, AppError::Config(msg) => { tracing::error!(error = %msg, "Configuration error"); ( @@ -112,7 +114,7 @@ impl IntoResponse for AppError { "config_error", "Server configuration error".to_string(), ) - } + }, }; let body = ErrorResponse::new(code, message); diff --git a/crates/pm-core/src/lib.rs b/crates/pm-core/src/lib.rs index f9c6473..2f12b50 100644 --- a/crates/pm-core/src/lib.rs +++ b/crates/pm-core/src/lib.rs @@ -1,20 +1,19 @@ +pub mod audit; pub mod config; pub mod db; pub mod error; pub mod logging; pub mod models; -pub mod audit; pub mod request_id; // Re-export commonly used types -pub use error::{AppError, ErrorResponse}; pub use config::AppConfig; +pub use error::{AppError, ErrorResponse}; pub use models::{ - Host, HostSummary, HostHealthStatus, CreateHostRequest, - Group, CreateGroupRequest, UpdateGroupRequest, - User, UserRole as DbUserRole, AuthProvider, CreateUserRequest, UpdateUserRequest, - DiscoveryResult, DiscoveryCidrRequest, RegisterDiscoveredRequest, + AuthProvider, CreateGroupRequest, CreateHostRequest, CreateUserRequest, DiscoveryCidrRequest, + DiscoveryResult, Group, Host, HostHealthStatus, HostSummary, RegisterDiscoveredRequest, + UpdateGroupRequest, UpdateUserRequest, User, UserRole as DbUserRole, }; // Re-export audit integrity types -pub use audit::{verify_integrity, IntegrityResult, IntegrityError}; +pub use audit::{verify_integrity, IntegrityError, IntegrityResult}; diff --git a/crates/pm-core/src/logging.rs b/crates/pm-core/src/logging.rs index c6a286e..188821c 100644 --- a/crates/pm-core/src/logging.rs +++ b/crates/pm-core/src/logging.rs @@ -1,5 +1,5 @@ -use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use crate::config::LoggingConfig; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; /// Initialize the global tracing subscriber. /// @@ -10,8 +10,7 @@ use crate::config::LoggingConfig; /// Log level is controlled by `cfg.level` (e.g. `"info"`, `"debug"`). /// The `RUST_LOG` environment variable overrides `cfg.level`. pub fn init(cfg: &LoggingConfig) { - let filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new(&cfg.level)); + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cfg.level)); match cfg.format.as_str() { "json" => { @@ -19,13 +18,13 @@ pub fn init(cfg: &LoggingConfig) { .with(filter) .with(fmt::layer().json().with_current_span(true)) .init(); - } + }, _ => { tracing_subscriber::registry() .with(filter) .with(fmt::layer().pretty()) .init(); - } + }, } tracing::info!(format = %cfg.format, level = %cfg.level, "Logging initialized"); diff --git a/crates/pm-core/src/models.rs b/crates/pm-core/src/models.rs index 7dcb3cb..bf7c799 100644 --- a/crates/pm-core/src/models.rs +++ b/crates/pm-core/src/models.rs @@ -211,11 +211,11 @@ pub enum JobStatus { impl std::fmt::Display for JobStatus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Queued => write!(f, "queued"), - Self::Pending => write!(f, "pending"), - Self::Running => write!(f, "running"), + Self::Queued => write!(f, "queued"), + Self::Pending => write!(f, "pending"), + Self::Running => write!(f, "running"), Self::Succeeded => write!(f, "succeeded"), - Self::Failed => write!(f, "failed"), + Self::Failed => write!(f, "failed"), Self::Cancelled => write!(f, "cancelled"), } } @@ -321,9 +321,9 @@ pub enum WindowRecurrence { impl std::fmt::Display for WindowRecurrence { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Once => write!(f, "once"), - Self::Daily => write!(f, "daily"), - Self::Weekly => write!(f, "weekly"), + Self::Once => write!(f, "once"), + Self::Daily => write!(f, "daily"), + Self::Weekly => write!(f, "weekly"), Self::Monthly => write!(f, "monthly"), } } diff --git a/crates/pm-core/src/request_id.rs b/crates/pm-core/src/request_id.rs index 68cce01..e5d8dc8 100644 --- a/crates/pm-core/src/request_id.rs +++ b/crates/pm-core/src/request_id.rs @@ -1,9 +1,4 @@ -use axum::{ - extract::Request, - http::HeaderValue, - middleware::Next, - response::Response, -}; +use axum::{extract::Request, http::HeaderValue, middleware::Next, response::Response}; use ulid::Ulid; /// HTTP header name for request correlation IDs. diff --git a/crates/pm-reports/src/csv.rs b/crates/pm-reports/src/csv.rs index 5707d8f..0820113 100644 --- a/crates/pm-reports/src/csv.rs +++ b/crates/pm-reports/src/csv.rs @@ -4,10 +4,7 @@ use crate::{ReportParams, ReportType}; use anyhow::Context; /// Generate a CSV report and return the raw bytes. -pub async fn generate_csv( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +pub async fn generate_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { match params.report_type { ReportType::Compliance => compliance_csv(pool, params).await, ReportType::PatchHistory => patch_history_csv(pool, params).await, @@ -20,12 +17,10 @@ pub async fn generate_csv( // Compliance // --------------------------------------------------------------------------- -async fn compliance_csv( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +async fn compliance_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { let rows = if let Some(gid) = params.group_id { - sqlx::query(" + sqlx::query( + " SELECT h.id::text AS host_id, h.display_name, @@ -47,13 +42,15 @@ WHERE h.id IN ( ) GROUP BY h.id, pd.total_packages, pd.pending_patches ORDER BY compliance_pct ASC -") +", + ) .bind(gid) .fetch_all(pool) .await .context("compliance query (group filter) failed")? } else { - sqlx::query(" + sqlx::query( + " SELECT h.id::text AS host_id, h.display_name, @@ -72,7 +69,8 @@ LEFT JOIN host_groups hg ON hg.host_id = h.id LEFT JOIN groups g ON g.id = hg.group_id GROUP BY h.id, pd.total_packages, pd.pending_patches ORDER BY compliance_pct ASC -") +", + ) .fetch_all(pool) .await .context("compliance query failed")? @@ -80,23 +78,29 @@ ORDER BY compliance_pct ASC let mut wtr = csv::Writer::from_writer(vec![]); wtr.write_record(&[ - "host_id", "display_name", "fqdn", "group_names", - "total_packages", "pending_patches", "compliance_pct", - "last_patch_at", "health_status", + "host_id", + "display_name", + "fqdn", + "group_names", + "total_packages", + "pending_patches", + "compliance_pct", + "last_patch_at", + "health_status", ])?; for row in &rows { use sqlx::Row; - let host_id: String = row.try_get("host_id").unwrap_or_default(); - let display_name: String = row.try_get("display_name").unwrap_or_default(); - let fqdn: String = row.try_get("fqdn").unwrap_or_default(); - let group_names: String = row.try_get("group_names").unwrap_or_default(); - let total_packages: i64 = row.try_get("total_packages").unwrap_or(0); - let pending_patches: i64 = row.try_get("pending_patches").unwrap_or(0); - let compliance_pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0); + let host_id: String = row.try_get("host_id").unwrap_or_default(); + let display_name: String = row.try_get("display_name").unwrap_or_default(); + let fqdn: String = row.try_get("fqdn").unwrap_or_default(); + let group_names: String = row.try_get("group_names").unwrap_or_default(); + let total_packages: i64 = row.try_get("total_packages").unwrap_or(0); + let pending_patches: i64 = row.try_get("pending_patches").unwrap_or(0); + let compliance_pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0); let last_patch_at: Option> = row.try_get("last_patch_at").unwrap_or(None); - let health_status: String = row.try_get("health_status").unwrap_or_default(); + let health_status: String = row.try_get("health_status").unwrap_or_default(); wtr.write_record(&[ host_id, @@ -118,11 +122,9 @@ ORDER BY compliance_pct ASC // Patch history // --------------------------------------------------------------------------- -async fn patch_history_csv( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { - let rows = sqlx::query(" +async fn patch_history_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { + let rows = sqlx::query( + " SELECT pj.id::text AS job_id, pj.kind::text AS job_kind, @@ -141,7 +143,8 @@ LEFT JOIN users u ON u.id = pj.created_by_user_id WHERE ($1::timestamptz IS NULL OR pjh.started_at >= $1) AND ($2::timestamptz IS NULL OR pjh.started_at <= $2) ORDER BY pjh.started_at DESC -") +", + ) .bind(params.from) .bind(params.to) .fetch_all(pool) @@ -150,24 +153,32 @@ ORDER BY pjh.started_at DESC let mut wtr = csv::Writer::from_writer(vec![]); wtr.write_record(&[ - "job_id", "job_kind", "job_status", "host_display_name", "host_fqdn", - "package_count", "started_at", "completed_at", "duration_seconds", "operator", + "job_id", + "job_kind", + "job_status", + "host_display_name", + "host_fqdn", + "package_count", + "started_at", + "completed_at", + "duration_seconds", + "operator", ])?; for row in &rows { use sqlx::Row; - let job_id: String = row.try_get("job_id").unwrap_or_default(); - let job_kind: String = row.try_get("job_kind").unwrap_or_default(); - let job_status: String = row.try_get("job_status").unwrap_or_default(); + let job_id: String = row.try_get("job_id").unwrap_or_default(); + let job_kind: String = row.try_get("job_kind").unwrap_or_default(); + let job_status: String = row.try_get("job_status").unwrap_or_default(); let display_name: String = row.try_get("display_name").unwrap_or_default(); - let fqdn: String = row.try_get("fqdn").unwrap_or_default(); - let package_count: i64 = row.try_get("package_count").unwrap_or(0); + let fqdn: String = row.try_get("fqdn").unwrap_or_default(); + let package_count: i64 = row.try_get("package_count").unwrap_or(0); let started_at: Option> = row.try_get("started_at").unwrap_or(None); let completed_at: Option> = row.try_get("completed_at").unwrap_or(None); let duration_seconds: Option = row.try_get("duration_seconds").unwrap_or(None); - let operator: String = row.try_get("operator").unwrap_or_default(); + let operator: String = row.try_get("operator").unwrap_or_default(); wtr.write_record(&[ job_id, @@ -190,17 +201,21 @@ ORDER BY pjh.started_at DESC // Vulnerability // --------------------------------------------------------------------------- -async fn vulnerability_csv( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +async fn vulnerability_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { let mut wtr = csv::Writer::from_writer(vec![]); wtr.write_record(&[ - "host_id", "display_name", "fqdn", "cve_id", - "package_name", "severity", "available_version", "last_seen_at", + "host_id", + "display_name", + "fqdn", + "cve_id", + "package_name", + "severity", + "available_version", + "last_seen_at", ])?; - let result = sqlx::query(" + let result = sqlx::query( + " SELECT h.id::text AS host_id, h.display_name, @@ -224,7 +239,8 @@ ORDER BY ELSE 4 END, h.display_name -") +", + ) .bind(params.from) .bind(params.to) .fetch_all(pool) @@ -234,12 +250,12 @@ ORDER BY Ok(rows) => { for row in &rows { use sqlx::Row; - let host_id: String = row.try_get("host_id").unwrap_or_default(); - let display_name: String = row.try_get("display_name").unwrap_or_default(); - let fqdn: String = row.try_get("fqdn").unwrap_or_default(); - let cve_id: String = row.try_get("cve_id").unwrap_or_default(); - let package_name: String = row.try_get("package_name").unwrap_or_default(); - let severity: String = row.try_get("severity").unwrap_or_default(); + let host_id: String = row.try_get("host_id").unwrap_or_default(); + let display_name: String = row.try_get("display_name").unwrap_or_default(); + let fqdn: String = row.try_get("fqdn").unwrap_or_default(); + let cve_id: String = row.try_get("cve_id").unwrap_or_default(); + let package_name: String = row.try_get("package_name").unwrap_or_default(); + let severity: String = row.try_get("severity").unwrap_or_default(); let available_version: String = row.try_get("available_version").unwrap_or_default(); let last_seen_at: Option> = @@ -256,15 +272,21 @@ ORDER BY last_seen_at.map(|d| d.to_rfc3339()).unwrap_or_default(), ])?; } - } + }, Err(e) => { tracing::warn!(error = %e, "vulnerability query failed — returning empty rows"); // write a comment row indicating empty data wtr.write_record(&[ - "(no data)", "", "", "", "", "", "", + "(no data)", + "", + "", + "", + "", + "", + "", &format!("query error: {}", e), ])?; - } + }, } Ok(wtr.into_inner().context("csv flush failed")?) @@ -274,11 +296,9 @@ ORDER BY // Audit // --------------------------------------------------------------------------- -async fn audit_csv( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { - let rows = sqlx::query(" +async fn audit_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { + let rows = sqlx::query( + " SELECT id::text AS id, created_at, @@ -293,7 +313,8 @@ WHERE ($1::timestamptz IS NULL OR created_at >= $1) AND ($2::timestamptz IS NULL OR created_at <= $2) ORDER BY created_at DESC LIMIT 10000 -") +", + ) .bind(params.from) .bind(params.to) .fetch_all(pool) @@ -302,21 +323,27 @@ LIMIT 10000 let mut wtr = csv::Writer::from_writer(vec![]); wtr.write_record(&[ - "id", "created_at", "action", "actor_username", - "target_type", "target_id", "ip_address", "request_id", + "id", + "created_at", + "action", + "actor_username", + "target_type", + "target_id", + "ip_address", + "request_id", ])?; for row in &rows { use sqlx::Row; - let id: String = row.try_get("id").unwrap_or_default(); + let id: String = row.try_get("id").unwrap_or_default(); let created_at: Option> = row.try_get("created_at").unwrap_or(None); - let action: String = row.try_get("action").unwrap_or_default(); + let action: String = row.try_get("action").unwrap_or_default(); let actor_username: String = row.try_get("actor_username").unwrap_or_default(); - let target_type: String = row.try_get("target_type").unwrap_or_default(); - let target_id: String = row.try_get("target_id").unwrap_or_default(); - let ip_address: String = row.try_get("ip_address").unwrap_or_default(); - let request_id: String = row.try_get("request_id").unwrap_or_default(); + let target_type: String = row.try_get("target_type").unwrap_or_default(); + let target_id: String = row.try_get("target_id").unwrap_or_default(); + let ip_address: String = row.try_get("ip_address").unwrap_or_default(); + let request_id: String = row.try_get("request_id").unwrap_or_default(); wtr.write_record(&[ id, diff --git a/crates/pm-reports/src/pdf.rs b/crates/pm-reports/src/pdf.rs index 03c6368..c8d9c7c 100644 --- a/crates/pm-reports/src/pdf.rs +++ b/crates/pm-reports/src/pdf.rs @@ -6,9 +6,8 @@ use crate::{ReportParams, ReportType}; use anyhow::Context; use plotters::prelude::*; use printpdf::{ - BuiltinFont, ColorBits, ColorSpace, Image, ImageTransform, ImageXObject, - IndirectFontRef, Mm, PdfDocument, PdfLayerIndex, PdfLayerReference, - PdfPageIndex, Px, + BuiltinFont, ColorBits, ColorSpace, Image, ImageTransform, ImageXObject, IndirectFontRef, Mm, + PdfDocument, PdfLayerIndex, PdfLayerReference, PdfPageIndex, Px, }; const PAGE_W: f32 = 297.0; // A4 landscape width (mm) @@ -23,15 +22,12 @@ const NEW_PAGE_THRESHOLD: f32 = 20.0; // --------------------------------------------------------------------------- /// Generate a PDF report and return the raw bytes. -pub async fn generate_pdf( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +pub async fn generate_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { match params.report_type { - ReportType::Compliance => compliance_pdf(pool, params).await, - ReportType::PatchHistory => patch_history_pdf(pool, params).await, + ReportType::Compliance => compliance_pdf(pool, params).await, + ReportType::PatchHistory => patch_history_pdf(pool, params).await, ReportType::Vulnerability => vulnerability_pdf(pool, params).await, - ReportType::Audit => audit_pdf(pool, params).await, + ReportType::Audit => audit_pdf(pool, params).await, } } @@ -51,15 +47,10 @@ fn render_bar_chart( let mut pixel_buf = vec![0u8; (W * H * 3) as usize]; { - let root = BitMapBackend::with_buffer(&mut pixel_buf, (W, H)) - .into_drawing_area(); + let root = BitMapBackend::with_buffer(&mut pixel_buf, (W, H)).into_drawing_area(); root.fill(&WHITE)?; - let max_val = values - .iter() - .cloned() - .fold(0.0_f64, f64::max) - .max(1.0); + let max_val = values.iter().cloned().fold(0.0_f64, f64::max).max(1.0); let n = labels.len().max(1); let mut chart = ChartBuilder::on(&root) @@ -76,7 +67,11 @@ fn render_bar_chart( labels .get(*idx) .map(|s| { - if s.len() > 12 { s[..12].to_string() } else { s.clone() } + if s.len() > 12 { + s[..12].to_string() + } else { + s.clone() + } }) .unwrap_or_default() }) @@ -119,7 +114,7 @@ impl PdfBuilder { fn new(title: &str) -> anyhow::Result { let doc = PdfDocument::empty(title); let (page_idx, layer_idx) = doc.add_page(Mm(PAGE_W), Mm(PAGE_H), "Layer 1"); - let font = doc.add_builtin_font(BuiltinFont::Helvetica)?; + let font = doc.add_builtin_font(BuiltinFont::Helvetica)?; let font_bold = doc.add_builtin_font(BuiltinFont::HelveticaBold)?; Ok(Self { doc, @@ -212,19 +207,31 @@ impl PdfBuilder { fn write_title_page(pdf: &mut PdfBuilder, title: &str, params: &ReportParams) { pdf.write_text(title, 24.0, MARGIN, 160.0, true); pdf.write_text( - &format!("Generated: {}", chrono::Utc::now().format("%Y-%m-%d %H:%M UTC")), - 11.0, MARGIN, 148.0, false, + &format!( + "Generated: {}", + chrono::Utc::now().format("%Y-%m-%d %H:%M UTC") + ), + 11.0, + MARGIN, + 148.0, + false, ); if let Some(from) = params.from { pdf.write_text( &format!("From: {}", from.format("%Y-%m-%d")), - 10.0, MARGIN, 140.0, false, + 10.0, + MARGIN, + 140.0, + false, ); } if let Some(to) = params.to { pdf.write_text( &format!("To: {}", to.format("%Y-%m-%d")), - 10.0, MARGIN, 134.0, false, + 10.0, + MARGIN, + 134.0, + false, ); } if let Some(gid) = params.group_id { @@ -237,13 +244,11 @@ fn write_title_page(pdf: &mut PdfBuilder, title: &str, params: &ReportParams) { // Compliance PDF // --------------------------------------------------------------------------- -async fn compliance_pdf( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +async fn compliance_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { use sqlx::Row; let rows = if let Some(gid) = params.group_id { - sqlx::query(" + sqlx::query( + " SELECT h.display_name, h.fqdn, COALESCE(pd.total_packages,0) AS total_packages, COALESCE(pd.pending_patches,0) AS pending_patches, @@ -254,11 +259,15 @@ SELECT h.display_name, h.fqdn, FROM hosts h LEFT JOIN host_patch_data pd ON pd.host_id=h.id WHERE h.id IN (SELECT host_id FROM host_groups WHERE group_id=$1) GROUP BY h.id,pd.total_packages,pd.pending_patches -ORDER BY compliance_pct ASC") - .bind(gid).fetch_all(pool).await - .context("compliance PDF query (group) failed")? +ORDER BY compliance_pct ASC", + ) + .bind(gid) + .fetch_all(pool) + .await + .context("compliance PDF query (group) failed")? } else { - sqlx::query(" + sqlx::query( + " SELECT h.display_name, h.fqdn, COALESCE(pd.total_packages,0) AS total_packages, COALESCE(pd.pending_patches,0) AS pending_patches, @@ -268,26 +277,55 @@ SELECT h.display_name, h.fqdn, h.health_status::text AS health_status FROM hosts h LEFT JOIN host_patch_data pd ON pd.host_id=h.id GROUP BY h.id,pd.total_packages,pd.pending_patches -ORDER BY compliance_pct ASC") - .fetch_all(pool).await - .context("compliance PDF query failed")? +ORDER BY compliance_pct ASC", + ) + .fetch_all(pool) + .await + .context("compliance PDF query failed")? }; - let labels: Vec = rows.iter().map(|r| r.try_get::("display_name").unwrap_or_default()).collect(); - let values: Vec = rows.iter().map(|r| r.try_get::("compliance_pct").unwrap_or(0.0)).collect(); + let labels: Vec = rows + .iter() + .map(|r| r.try_get::("display_name").unwrap_or_default()) + .collect(); + let values: Vec = rows + .iter() + .map(|r| r.try_get::("compliance_pct").unwrap_or(0.0)) + .collect(); let mut pdf = PdfBuilder::new("Compliance Report")?; write_title_page(&mut pdf, "Compliance Report", params); let col_x: &[f32] = &[MARGIN, 65.0, 130.0, 165.0, 200.0, 235.0]; - pdf.table_row(&["Host","FQDN","Total Pkgs","Pending","Compliance %","Status"], col_x, 9.0, true); + pdf.table_row( + &[ + "Host", + "FQDN", + "Total Pkgs", + "Pending", + "Compliance %", + "Status", + ], + col_x, + 9.0, + true, + ); for row in &rows { - let name: String = row.try_get("display_name").unwrap_or_default(); - let fqdn: String = row.try_get("fqdn").unwrap_or_default(); - let total: i64 = row.try_get("total_packages").unwrap_or(0); - let pend: i64 = row.try_get("pending_patches").unwrap_or(0); - let pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0); + let name: String = row.try_get("display_name").unwrap_or_default(); + let fqdn: String = row.try_get("fqdn").unwrap_or_default(); + let total: i64 = row.try_get("total_packages").unwrap_or(0); + let pend: i64 = row.try_get("pending_patches").unwrap_or(0); + let pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0); let status: String = row.try_get("health_status").unwrap_or_default(); pdf.table_row( - &[&name,&fqdn,&total.to_string(),&pend.to_string(),&format!("{:.1}%",pct),&status], - col_x, 8.0, false, + &[ + &name, + &fqdn, + &total.to_string(), + &pend.to_string(), + &format!("{:.1}%", pct), + &status, + ], + col_x, + 8.0, + false, ); } if !labels.is_empty() { @@ -298,7 +336,7 @@ ORDER BY compliance_pct ASC") if let Err(e) = pdf.embed_image(raw, w, h, MARGIN, 10.0, 0.18, 0.18) { tracing::warn!(error = %e, "chart embed failed"); } - } + }, Err(e) => tracing::warn!(error = %e, "chart render failed"), } } @@ -309,12 +347,10 @@ ORDER BY compliance_pct ASC") // Patch history PDF // --------------------------------------------------------------------------- -async fn patch_history_pdf( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +async fn patch_history_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { use sqlx::Row; - let rows = sqlx::query(" + let rows = sqlx::query( + " SELECT pj.kind::text AS job_kind, pj.status::text AS job_status, h.display_name, h.fqdn, pjh.started_at, pjh.completed_at, EXTRACT(EPOCH FROM (pjh.completed_at-pjh.started_at))::bigint AS duration_seconds, @@ -325,33 +361,71 @@ JOIN hosts h ON h.id=pjh.host_id LEFT JOIN users u ON u.id=pj.created_by_user_id WHERE ($1::timestamptz IS NULL OR pjh.started_at>=$1) AND ($2::timestamptz IS NULL OR pjh.started_at<=$2) -ORDER BY pjh.started_at DESC") - .bind(params.from).bind(params.to).fetch_all(pool).await - .context("patch history PDF query failed")?; - let mut dc: std::collections::BTreeMap = std::collections::BTreeMap::new(); +ORDER BY pjh.started_at DESC", + ) + .bind(params.from) + .bind(params.to) + .fetch_all(pool) + .await + .context("patch history PDF query failed")?; + let mut dc: std::collections::BTreeMap = std::collections::BTreeMap::new(); for row in &rows { - if let Ok(Some(s)) = row.try_get::>,_>("started_at") { + if let Ok(Some(s)) = row.try_get::>, _>("started_at") { *dc.entry(s.format("%Y-%m-%d").to_string()).or_insert(0.0) += 1.0; } } let cl: Vec = dc.keys().cloned().collect(); - let cv: Vec = dc.values().cloned().collect(); + let cv: Vec = dc.values().cloned().collect(); let mut pdf = PdfBuilder::new("Patch History Report")?; write_title_page(&mut pdf, "Patch History Report", params); - let col_x: &[f32] = &[MARGIN,45.0,80.0,115.0,155.0,200.0,245.0,270.0]; - pdf.table_row(&["Kind","Status","Host","FQDN","Started","Completed","Dur(s)","Operator"], col_x, 9.0, true); + let col_x: &[f32] = &[MARGIN, 45.0, 80.0, 115.0, 155.0, 200.0, 245.0, 270.0]; + pdf.table_row( + &[ + "Kind", + "Status", + "Host", + "FQDN", + "Started", + "Completed", + "Dur(s)", + "Operator", + ], + col_x, + 9.0, + true, + ); for row in &rows { - let kind: String = row.try_get("job_kind").unwrap_or_default(); + let kind: String = row.try_get("job_kind").unwrap_or_default(); let status: String = row.try_get("job_status").unwrap_or_default(); - let name: String = row.try_get("display_name").unwrap_or_default(); - let fqdn: String = row.try_get("fqdn").unwrap_or_default(); - let started: String = row.try_get::>,_>("started_at") - .unwrap_or(None).map(|d| d.format("%Y-%m-%d %H:%M").to_string()).unwrap_or_default(); - let completed: String = row.try_get::>,_>("completed_at") - .unwrap_or(None).map(|d| d.format("%Y-%m-%d %H:%M").to_string()).unwrap_or_default(); - let dur: i64 = row.try_get("duration_seconds").unwrap_or(0); - let op: String = row.try_get("operator").unwrap_or_default(); - pdf.table_row(&[&kind,&status,&name,&fqdn,&started,&completed,&dur.to_string(),&op], col_x, 8.0, false); + let name: String = row.try_get("display_name").unwrap_or_default(); + let fqdn: String = row.try_get("fqdn").unwrap_or_default(); + let started: String = row + .try_get::>, _>("started_at") + .unwrap_or(None) + .map(|d| d.format("%Y-%m-%d %H:%M").to_string()) + .unwrap_or_default(); + let completed: String = row + .try_get::>, _>("completed_at") + .unwrap_or(None) + .map(|d| d.format("%Y-%m-%d %H:%M").to_string()) + .unwrap_or_default(); + let dur: i64 = row.try_get("duration_seconds").unwrap_or(0); + let op: String = row.try_get("operator").unwrap_or_default(); + pdf.table_row( + &[ + &kind, + &status, + &name, + &fqdn, + &started, + &completed, + &dur.to_string(), + &op, + ], + col_x, + 8.0, + false, + ); } if !cl.is_empty() { match render_bar_chart(&cl, &cv, "Jobs per Day") { @@ -361,7 +435,7 @@ ORDER BY pjh.started_at DESC") if let Err(e) = pdf.embed_image(raw, w, h, MARGIN, 10.0, 0.18, 0.18) { tracing::warn!(error = %e, "chart embed failed"); } - } + }, Err(e) => tracing::warn!(error = %e, "chart render failed"), } } @@ -372,10 +446,7 @@ ORDER BY pjh.started_at DESC") // Vulnerability PDF // --------------------------------------------------------------------------- -async fn vulnerability_pdf( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +async fn vulnerability_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { use sqlx::Row; // Query DB FIRST (before creating any non-Send PdfBuilder) let query_result = sqlx::query(" @@ -393,61 +464,104 @@ ORDER BY CASE cve.severity WHEN 'critical' THEN 1 WHEN 'high' THEN 2 WHEN 'mediu // Now create PdfBuilder (non-Send Rc types) after all awaits let mut pdf = PdfBuilder::new("Vulnerability Report")?; write_title_page(&mut pdf, "Vulnerability Exposure Report", params); - let col_x: &[f32] = &[MARGIN,55.0,100.0,130.0,175.0,215.0,255.0]; - pdf.table_row(&["Host","FQDN","CVE ID","Package","Severity","Fix Version","Last Seen"], col_x, 9.0, true); + let col_x: &[f32] = &[MARGIN, 55.0, 100.0, 130.0, 175.0, 215.0, 255.0]; + pdf.table_row( + &[ + "Host", + "FQDN", + "CVE ID", + "Package", + "Severity", + "Fix Version", + "Last Seen", + ], + col_x, + 9.0, + true, + ); match query_result { Ok(rows) => { for row in &rows { let name: String = row.try_get("display_name").unwrap_or_default(); let fqdn: String = row.try_get("fqdn").unwrap_or_default(); - let cve: String = row.try_get("cve_id").unwrap_or_default(); - let pkg: String = row.try_get("package_name").unwrap_or_default(); - let sev: String = row.try_get("severity").unwrap_or_default(); - let fix: String = row.try_get("available_version").unwrap_or_default(); - let seen: String = row.try_get::>,_>("last_seen_at") - .unwrap_or(None).map(|d| d.format("%Y-%m-%d").to_string()).unwrap_or_default(); - pdf.table_row(&[&name,&fqdn,&cve,&pkg,&sev,&fix,&seen], col_x, 8.0, false); + let cve: String = row.try_get("cve_id").unwrap_or_default(); + let pkg: String = row.try_get("package_name").unwrap_or_default(); + let sev: String = row.try_get("severity").unwrap_or_default(); + let fix: String = row.try_get("available_version").unwrap_or_default(); + let seen: String = row + .try_get::>, _>("last_seen_at") + .unwrap_or(None) + .map(|d| d.format("%Y-%m-%d").to_string()) + .unwrap_or_default(); + pdf.table_row( + &[&name, &fqdn, &cve, &pkg, &sev, &fix, &seen], + col_x, + 8.0, + false, + ); } - } + }, Err(e) => { tracing::warn!(error = %e, "vulnerability PDF query failed"); let y = pdf.current_y; pdf.write_text(&format!("No data: {}", e), 10.0, MARGIN, y, false); - } + }, } pdf.save() } - -async fn audit_pdf( - pool: &sqlx::PgPool, - params: &ReportParams, -) -> anyhow::Result> { +async fn audit_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result> { use sqlx::Row; - let rows = sqlx::query(" + let rows = sqlx::query( + " SELECT id::text AS id, created_at, action::text AS action, actor_username, target_type, target_id, ip_address::text AS ip_address, request_id FROM audit_log WHERE ($1::timestamptz IS NULL OR created_at>=$1) AND ($2::timestamptz IS NULL OR created_at<=$2) -ORDER BY created_at DESC LIMIT 10000") - .bind(params.from).bind(params.to).fetch_all(pool).await - .context("audit PDF query failed")?; +ORDER BY created_at DESC LIMIT 10000", + ) + .bind(params.from) + .bind(params.to) + .fetch_all(pool) + .await + .context("audit PDF query failed")?; let mut pdf = PdfBuilder::new("Audit Trail Report")?; write_title_page(&mut pdf, "Audit Trail Report", params); - let col_x: &[f32] = &[MARGIN,50.0,95.0,135.0,175.0,215.0,255.0]; - pdf.table_row(&["Timestamp","Action","Actor","Target Type","Target ID","IP","Request ID"], col_x, 9.0, true); + let col_x: &[f32] = &[MARGIN, 50.0, 95.0, 135.0, 175.0, 215.0, 255.0]; + pdf.table_row( + &[ + "Timestamp", + "Action", + "Actor", + "Target Type", + "Target ID", + "IP", + "Request ID", + ], + col_x, + 9.0, + true, + ); for row in &rows { - let created: String = row.try_get::>,_>("created_at") - .unwrap_or(None).map(|d| d.format("%Y-%m-%d %H:%M").to_string()).unwrap_or_default(); + let created: String = row + .try_get::>, _>("created_at") + .unwrap_or(None) + .map(|d| d.format("%Y-%m-%d %H:%M").to_string()) + .unwrap_or_default(); let action: String = row.try_get("action").unwrap_or_default(); - let actor: String = row.try_get("actor_username").unwrap_or_default(); - let ttype: String = row.try_get("target_type").unwrap_or_default(); - let tid: String = row.try_get("target_id").unwrap_or_default(); - let ip: String = row.try_get("ip_address").unwrap_or_default(); - let req: String = row.try_get("request_id").unwrap_or_default(); - pdf.table_row(&[&created,&action,&actor,&ttype,&tid,&ip,&req], col_x, 8.0, false); + let actor: String = row.try_get("actor_username").unwrap_or_default(); + let ttype: String = row.try_get("target_type").unwrap_or_default(); + let tid: String = row.try_get("target_id").unwrap_or_default(); + let ip: String = row.try_get("ip_address").unwrap_or_default(); + let req: String = row.try_get("request_id").unwrap_or_default(); + pdf.table_row( + &[&created, &action, &actor, &ttype, &tid, &ip, &req], + col_x, + 8.0, + false, + ); } pdf.save() } diff --git a/crates/pm-web/src/main.rs b/crates/pm-web/src/main.rs index 7257e38..81ffe8c 100644 --- a/crates/pm-web/src/main.rs +++ b/crates/pm-web/src/main.rs @@ -2,37 +2,18 @@ mod routes; -use axum::{ - extract::State, - http::StatusCode, - middleware, - response::Json, - routing::get, - Router, -}; +use axum::{extract::State, http::StatusCode, middleware, response::Json, routing::get, Router}; use dashmap::DashMap; -use pm_core::{ - config::AppConfig, - db, - logging, - request_id::request_id_middleware, -}; use pm_auth::{ jwt, - rbac::{AuthConfig, require_auth}, + rbac::{require_auth, AuthConfig}, }; -use routes::ws::WsTicket; +use pm_core::{config::AppConfig, db, logging, request_id::request_id_middleware}; use routes::azure_sso::SsoSession; +use routes::ws::WsTicket; use serde_json::{json, Value}; -use std::{ - net::SocketAddr, - sync::Arc, - time::Duration, -}; -use tower_http::{ - services::ServeDir, - trace::TraceLayer, -}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tower_http::{services::ServeDir, trace::TraceLayer}; /// Shared application state threaded through Axum. #[derive(Clone)] @@ -60,7 +41,10 @@ async fn main() -> anyhow::Result<()> { }); logging::init(&config.logging); - tracing::info!(version = env!("CARGO_PKG_VERSION"), "patch-manager-web starting"); + tracing::info!( + version = env!("CARGO_PKG_VERSION"), + "patch-manager-web starting" + ); let signing_key_pem = jwt::load_signing_key(&config.security.jwt_signing_key_path) .unwrap_or_else(|e| { @@ -68,8 +52,8 @@ async fn main() -> anyhow::Result<()> { String::new() }); - let verify_key_pem = jwt::load_verify_key(&config.security.jwt_verify_key_path) - .unwrap_or_else(|e| { + let verify_key_pem = + jwt::load_verify_key(&config.security.jwt_verify_key_path).unwrap_or_else(|e| { tracing::warn!(error = %e, "JWT verify key not found (dev mode)"); String::new() }); @@ -159,7 +143,10 @@ pub fn build_router(state: AppState) -> Router { // Patch jobs .nest("/jobs", routes::jobs::router()) // Maintenance windows (nested under hosts path param) - .nest("/hosts/:host_id/maintenance-windows", routes::maintenance_windows::router()) + .nest( + "/hosts/:host_id/maintenance-windows", + routes::maintenance_windows::router(), + ) // CA root certificate download .nest("/ca", routes::ca::ca_router()) // Certificate list / renew / revoke @@ -187,9 +174,7 @@ pub fn build_router(state: AppState) -> Router { // WebSocket browser endpoint — ticket-authenticated, outside JWT middleware .merge(routes::ws::ws_router()) // Serve React SPA - .fallback_service( - ServeDir::new(&static_dir).append_index_html_on_directories(true), - ) + .fallback_service(ServeDir::new(&static_dir).append_index_html_on_directories(true)) .layer(middleware::from_fn(request_id_middleware)) .layer(TraceLayer::new_for_http()) .with_state(state) @@ -199,5 +184,9 @@ async fn health_handler(State(state): State) -> Result, St let db_ok = sqlx::query("SELECT 1").execute(&state.db).await.is_ok(); let status = if db_ok { "healthy" } else { "degraded" }; let body = json!({ "service": "patch-manager-web", "version": env!("CARGO_PKG_VERSION"), "status": status, "database": if db_ok { "ok" } else { "error" } }); - if db_ok { Ok(Json(body)) } else { Err(StatusCode::SERVICE_UNAVAILABLE) } + if db_ok { + Ok(Json(body)) + } else { + Err(StatusCode::SERVICE_UNAVAILABLE) + } } diff --git a/crates/pm-web/src/routes/auth.rs b/crates/pm-web/src/routes/auth.rs index 1ab2eb6..17fb69a 100644 --- a/crates/pm-web/src/routes/auth.rs +++ b/crates/pm-web/src/routes/auth.rs @@ -18,8 +18,8 @@ use axum::{ }; use pm_auth::{ mfa_totp, - session::{self, LoginRequest, LoginResponse}, rbac::AuthUser, + session::{self, LoginRequest, LoginResponse}, }; use serde::Deserialize; use serde_json::{json, Value}; @@ -107,10 +107,17 @@ async fn login_handler( ), _ => { tracing::error!(error = %e, "Login error"); - (StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "An error occurred") - } + ( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "An error occurred", + ) + }, }; - (status, Json(json!({ "error": { "code": code, "message": message } }))) + ( + status, + Json(json!({ "error": { "code": code, "message": message } })), + ) }) } @@ -156,10 +163,17 @@ async fn refresh_handler( ), _ => { tracing::error!(error = %e, "Refresh error"); - (StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "An error occurred") - } + ( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "An error occurred", + ) + }, }; - (status, Json(json!({ "error": { "code": code, "message": msg } }))) + ( + status, + Json(json!({ "error": { "code": code, "message": msg } })), + ) }) } @@ -221,11 +235,13 @@ async fn mfa_verify_handler( auth_user: AuthUser, Json(req): Json, ) -> Result, (StatusCode, Json)> { - let valid = mfa_totp::verify_code(&auth_user.username, &req.secret_base32, &req.code) - .map_err(|e| ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), - ))?; + let valid = + mfa_totp::verify_code(&auth_user.username, &req.secret_base32, &req.code).map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })?; if !valid { return Err(( diff --git a/crates/pm-web/src/routes/azure_sso.rs b/crates/pm-web/src/routes/azure_sso.rs index 41330c0..aa2b434 100644 --- a/crates/pm-web/src/routes/azure_sso.rs +++ b/crates/pm-web/src/routes/azure_sso.rs @@ -97,15 +97,19 @@ async fn azure_login( None => { return Err(( StatusCode::FORBIDDEN, - Json(json!({ "error": { "code": "forbidden", "message": "Azure SSO is not configured" } })), + Json( + json!({ "error": { "code": "forbidden", "message": "Azure SSO is not configured" } }), + ), )); - } + }, }; if !enabled { return Err(( StatusCode::FORBIDDEN, - Json(json!({ "error": { "code": "forbidden", "message": "Azure SSO is not enabled" } })), + Json( + json!({ "error": { "code": "forbidden", "message": "Azure SSO is not enabled" } }), + ), )); } @@ -162,7 +166,9 @@ async fn azure_callback( let desc = params.error_description.unwrap_or_default(); return Err(( StatusCode::BAD_REQUEST, - Json(json!({ "error": { "code": "sso_error", "message": format!("Azure AD error: {} - {}", error, desc) } })), + Json( + json!({ "error": { "code": "sso_error", "message": format!("Azure AD error: {} - {}", error, desc) } }), + ), )); } @@ -176,7 +182,9 @@ async fn azure_callback( let state_token = params.state.ok_or_else(|| { ( StatusCode::BAD_REQUEST, - Json(json!({ "error": { "code": "bad_request", "message": "Missing state parameter" } })), + Json( + json!({ "error": { "code": "bad_request", "message": "Missing state parameter" } }), + ), ) })?; @@ -211,9 +219,11 @@ async fn azure_callback( None => { return Err(( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": "Azure SSO not configured" } })), + Json( + json!({ "error": { "code": "internal_error", "message": "Azure SSO not configured" } }), + ), )); - } + }, }; // Exchange code for tokens @@ -263,7 +273,9 @@ async fn azure_callback( tracing::error!(status = %status, body = %body, "Token exchange failed"); return Err(( StatusCode::BAD_GATEWAY, - Json(json!({ "error": { "code": "sso_error", "message": format!("Token exchange failed: HTTP {}", status) } })), + Json( + json!({ "error": { "code": "sso_error", "message": format!("Token exchange failed: HTTP {}", status) } }), + ), )); } @@ -302,7 +314,9 @@ async fn azure_callback( if email.is_empty() || oid.is_empty() { return Err(( StatusCode::BAD_GATEWAY, - Json(json!({ "error": { "code": "sso_error", "message": "Missing email or oid in id_token" } })), + Json( + json!({ "error": { "code": "sso_error", "message": "Missing email or oid in id_token" } }), + ), )); } @@ -326,9 +340,11 @@ async fn azure_callback( Some(u) if !u.is_active => { return Err(( StatusCode::FORBIDDEN, - Json(json!({ "error": { "code": "account_disabled", "message": "Account is disabled" } })), + Json( + json!({ "error": { "code": "account_disabled", "message": "Account is disabled" } }), + ), )); - } + }, Some(u) => u, None => { // Auto-create user with role=operator, auth_provider=azure_sso @@ -372,22 +388,24 @@ async fn azure_callback( is_active: true, mfa_enabled: false, } - } + }, }; // Update last_login_at and azure_oid - sqlx::query("UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1) WHERE id = $2") - .bind(&oid) - .bind(user.id) - .execute(&state.db) - .await - .map_err(|e| { - tracing::error!(error = %e, "Failed to update last_login_at"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), - ) - })?; + sqlx::query( + "UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1) WHERE id = $2", + ) + .bind(&oid) + .bind(user.id) + .execute(&state.db) + .await + .map_err(|e| { + tracing::error!(error = %e, "Failed to update last_login_at"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) + })?; // Issue JWT access token + refresh token let access_ttl = state.config.security.jwt_access_ttl_secs as i64; @@ -466,6 +484,5 @@ fn decode_jwt_payload(token: &str) -> Result { .decode(&payload_b64_padded) .map_err(|e| format!("Base64 decode error: {}", e))?; - serde_json::from_slice(&payload_bytes) - .map_err(|e| format!("JSON parse error: {}", e)) + serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON parse error: {}", e)) } diff --git a/crates/pm-web/src/routes/ca.rs b/crates/pm-web/src/routes/ca.rs index 2dd37ac..0074825 100644 --- a/crates/pm-web/src/routes/ca.rs +++ b/crates/pm-web/src/routes/ca.rs @@ -33,8 +33,7 @@ use crate::AppState; /// Handles routes mounted at /api/v1/ca pub fn ca_router() -> Router { - Router::new() - .route("/root.crt", get(download_root_ca)) + Router::new().route("/root.crt", get(download_root_ca)) } /// Handles routes mounted at /api/v1/certificates @@ -84,10 +83,7 @@ struct IssueCertRequest { // ── Helper: build PEM download response ────────────────────────────────────── -fn pem_response( - pem: String, - filename: &str, -) -> Result, (StatusCode, Json)> { +fn pem_response(pem: String, filename: &str) -> Result, (StatusCode, Json)> { let disposition = format!("attachment; filename=\"{filename}\""); Response::builder() .status(StatusCode::OK) @@ -174,7 +170,7 @@ async fn list_certificates( .bind(st) .fetch_all(&state.db) .await - } + }, (Some(hid), None) => { sqlx::query_as::<_, CertRow>( r#"SELECT id, host_id, serial_number, common_name, @@ -187,7 +183,7 @@ async fn list_certificates( .bind(hid) .fetch_all(&state.db) .await - } + }, (None, Some(st)) => { sqlx::query_as::<_, CertRow>( r#"SELECT id, host_id, serial_number, common_name, @@ -200,7 +196,7 @@ async fn list_certificates( .bind(st) .fetch_all(&state.db) .await - } + }, (None, None) => { sqlx::query_as::<_, CertRow>( r#"SELECT id, host_id, serial_number, common_name, @@ -211,7 +207,7 @@ async fn list_certificates( ) .fetch_all(&state.db) .await - } + }, } .map_err(db_error)?; @@ -259,7 +255,7 @@ async fn download_client_cert( ) .await; pem_response(pem, "client.crt") - } + }, None => Err(( StatusCode::NOT_FOUND, Json(json!({ @@ -328,25 +324,23 @@ async fn renew_cert( ) -> Result, (StatusCode, Json)> { require_admin(&auth)?; - let issued = state - .ca - .renew_cert(cert_id, &state.db) - .await - .map_err(|e| { - let msg = e.to_string(); - tracing::error!(error = %e, %cert_id, "Failed to renew cert"); - if msg.contains("not found") { - ( - StatusCode::NOT_FOUND, - Json(json!({ "error": { "code": "not_found", "message": "Certificate not found" } })), - ) - } else { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": msg } })), - ) - } - })?; + let issued = state.ca.renew_cert(cert_id, &state.db).await.map_err(|e| { + let msg = e.to_string(); + tracing::error!(error = %e, %cert_id, "Failed to renew cert"); + if msg.contains("not found") { + ( + StatusCode::NOT_FOUND, + Json( + json!({ "error": { "code": "not_found", "message": "Certificate not found" } }), + ), + ) + } else { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": msg } })), + ) + } + })?; log_event( &state.db, diff --git a/crates/pm-web/src/routes/discovery.rs b/crates/pm-web/src/routes/discovery.rs index ee3d230..43f5c48 100644 --- a/crates/pm-web/src/routes/discovery.rs +++ b/crates/pm-web/src/routes/discovery.rs @@ -11,11 +11,11 @@ use axum::{ routing::{get, post}, Router, }; +use pm_auth::rbac::AuthUser; use pm_core::{ audit::{log_event, AuditAction}, models::{DiscoveryCidrRequest, DiscoveryResult, RegisterDiscoveredRequest}, }; -use pm_auth::rbac::AuthUser; use serde_json::{json, Value}; use std::{ net::{IpAddr, TcpStream}, @@ -46,13 +46,18 @@ async fn start_cidr_scan( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } - let cidr: ipnet::IpNet = req.cidr.parse().map_err(|_| ( - StatusCode::BAD_REQUEST, - Json(json!({ "error": { "code": "bad_request", "message": "Invalid CIDR range" } })) - ))?; + let cidr: ipnet::IpNet = req.cidr.parse().map_err(|_| { + ( + StatusCode::BAD_REQUEST, + Json(json!({ "error": { "code": "bad_request", "message": "Invalid CIDR range" } })), + ) + })?; let agent_port = req.agent_port.unwrap_or(12443) as u16; let scan_id = Uuid::new_v4(); @@ -67,13 +72,23 @@ async fn start_cidr_scan( run_cidr_scan(pool, scan_id_clone, cidr, agent_port).await; }); - log_event(&state.db, AuditAction::DiscoveryScanStarted, - Some(auth.user_id), Some(&auth.username), - Some("discovery"), Some(&scan_id.to_string()), - json!({ "cidr": cidr_str }), None, None).await; + log_event( + &state.db, + AuditAction::DiscoveryScanStarted, + Some(auth.user_id), + Some(&auth.username), + Some("discovery"), + Some(&scan_id.to_string()), + json!({ "cidr": cidr_str }), + None, + None, + ) + .await; tracing::info!(scan_id = %scan_id, cidr = %req.cidr, "CIDR scan started"); - Ok(Json(json!({ "scan_id": scan_id, "message": "Discovery scan started", "cidr": req.cidr }))) + Ok(Json( + json!({ "scan_id": scan_id, "message": "Discovery scan started", "cidr": req.cidr }), + )) } /// Background CIDR scanner. @@ -103,12 +118,7 @@ async fn run_cidr_scan(pool: sqlx::PgPool, scan_id: Uuid, cidr: ipnet::IpNet, po } /// Probe a single IP:port and store the result if the port is open. -async fn probe_and_store( - pool: sqlx::PgPool, - scan_id: Uuid, - ip: IpAddr, - port: u16, -) -> Option<()> { +async fn probe_and_store(pool: sqlx::PgPool, scan_id: Uuid, ip: IpAddr, port: u16) -> Option<()> { let addr = format!("{ip}:{port}"); // TCP connect probe (blocking, run in thread pool) @@ -116,9 +126,13 @@ async fn probe_and_store( let addr_clone = addr.clone(); let open = task::spawn_blocking(move || { TcpStream::connect_timeout( - &match addr_clone.parse() { Ok(a) => a, Err(_) => return false }, + &match addr_clone.parse() { + Ok(a) => a, + Err(_) => return false, + }, Duration::from_secs(PROBE_TIMEOUT_SECS), - ).is_ok() + ) + .is_ok() }) .await .unwrap_or(false); @@ -132,7 +146,8 @@ async fn probe_and_store( let fqdn = task::spawn_blocking(move || { use std::net::ToSocketAddrs; let addr = format!("{ip_clone}:{port}"); - addr.to_socket_addrs().ok() + addr.to_socket_addrs() + .ok() .and_then(|mut a| a.next()) .and_then(|_| dns_lookup_for_ip(ip_clone)) }) @@ -163,7 +178,10 @@ fn dns_lookup_for_ip(ip: IpAddr) -> Option { // Standard library doesn't have reverse lookup; use getaddrinfo via format let host = format!("{ip}"); // Best-effort: try to resolve numeric address to hostname - (host + ":0").to_socket_addrs().ok()?.next() + (host + ":0") + .to_socket_addrs() + .ok()? + .next() .map(|a| a.ip().to_string()) .filter(|s| s != &ip.to_string()) } @@ -188,7 +206,10 @@ async fn get_scan_results( .map(Json) .map_err(|e| { tracing::error!(error = %e); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) }) } @@ -201,7 +222,10 @@ async fn register_discovered_host( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } // Fetch discovery result @@ -213,7 +237,12 @@ async fn register_discovered_host( .bind(id) .fetch_optional(&state.db) .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?; + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })?; let result = result.ok_or_else(|| ( StatusCode::NOT_FOUND, @@ -235,7 +264,12 @@ async fn register_discovered_host( .bind(result.agent_port) .fetch_one(&state.db) .await - .map_err(|e| (StatusCode::CONFLICT, Json(json!({ "error": { "code": "conflict", "message": e.to_string() } }))))?; + .map_err(|e| { + ( + StatusCode::CONFLICT, + Json(json!({ "error": { "code": "conflict", "message": e.to_string() } })), + ) + })?; // Assign to groups if let Some(group_ids) = &req.group_ids { @@ -247,10 +281,24 @@ async fn register_discovered_host( // Mark as registered let _ = sqlx::query("UPDATE discovery_results SET registered = TRUE WHERE id = $1") - .bind(id).execute(&state.db).await; + .bind(id) + .execute(&state.db) + .await; - log_event(&state.db, AuditAction::HostRegistered, Some(auth.user_id), Some(&auth.username), - Some("host"), Some(&host_id.to_string()), json!({ "from_discovery": true, "ip": result.ip_address }), None, None).await; + log_event( + &state.db, + AuditAction::HostRegistered, + Some(auth.user_id), + Some(&auth.username), + Some("host"), + Some(&host_id.to_string()), + json!({ "from_discovery": true, "ip": result.ip_address }), + None, + None, + ) + .await; - Ok(Json(json!({ "host_id": host_id, "message": "Host registered from discovery" }))) + Ok(Json( + json!({ "host_id": host_id, "message": "Host registered from discovery" }), + )) } diff --git a/crates/pm-web/src/routes/groups.rs b/crates/pm-web/src/routes/groups.rs index 630a7ff..b23a98f 100644 --- a/crates/pm-web/src/routes/groups.rs +++ b/crates/pm-web/src/routes/groups.rs @@ -15,11 +15,11 @@ use axum::{ routing::{delete, get, post, put}, Router, }; +use pm_auth::rbac::AuthUser; use pm_core::{ audit::{log_event, AuditAction}, - models::{Group, CreateGroupRequest, UpdateGroupRequest}, + models::{CreateGroupRequest, Group, UpdateGroupRequest}, }; -use pm_auth::rbac::AuthUser; use serde_json::{json, Value}; use uuid::Uuid; @@ -28,8 +28,14 @@ use crate::AppState; pub fn router() -> Router { Router::new() .route("/", get(list_groups).post(create_group)) - .route("/:id", get(get_group).put(update_group).delete(delete_group)) - .route("/:id/users/:user_id", post(add_user_to_group).delete(remove_user_from_group)) + .route( + "/:id", + get(get_group).put(update_group).delete(delete_group), + ) + .route( + "/:id/users/:user_id", + post(add_user_to_group).delete(remove_user_from_group), + ) } async fn list_groups( @@ -37,14 +43,17 @@ async fn list_groups( _auth: AuthUser, ) -> Result>, (StatusCode, Json)> { sqlx::query_as::<_, Group>( - "SELECT id, name, description, created_at, updated_at FROM groups ORDER BY name" + "SELECT id, name, description, created_at, updated_at FROM groups ORDER BY name", ) .fetch_all(&state.db) .await .map(Json) .map_err(|e| { tracing::error!(error = %e, "Failed to list groups"); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) }) } @@ -54,23 +63,42 @@ async fn create_group( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } - let id: Uuid = sqlx::query_scalar( - "INSERT INTO groups (name, description) VALUES ($1, $2) RETURNING id" - ) - .bind(&req.name) - .bind(req.description.as_deref().unwrap_or("")) - .fetch_one(&state.db) - .await - .map_err(|e| { - let msg = if e.to_string().contains("unique") { "Group name already exists".to_string() } else { "Database error".to_string() }; - (StatusCode::CONFLICT, Json(json!({ "error": { "code": "conflict", "message": msg } }))) - })?; + let id: Uuid = + sqlx::query_scalar("INSERT INTO groups (name, description) VALUES ($1, $2) RETURNING id") + .bind(&req.name) + .bind(req.description.as_deref().unwrap_or("")) + .fetch_one(&state.db) + .await + .map_err(|e| { + let msg = if e.to_string().contains("unique") { + "Group name already exists".to_string() + } else { + "Database error".to_string() + }; + ( + StatusCode::CONFLICT, + Json(json!({ "error": { "code": "conflict", "message": msg } })), + ) + })?; - log_event(&state.db, AuditAction::GroupCreated, Some(auth.user_id), Some(&auth.username), - Some("group"), Some(&id.to_string()), json!({ "name": req.name }), None, None).await; + log_event( + &state.db, + AuditAction::GroupCreated, + Some(auth.user_id), + Some(&auth.username), + Some("group"), + Some(&id.to_string()), + json!({ "name": req.name }), + None, + None, + ) + .await; Ok(Json(json!({ "id": id, "message": "Group created" }))) } @@ -81,24 +109,43 @@ async fn get_group( Path(id): Path, ) -> Result, (StatusCode, Json)> { let group: Option = sqlx::query_as( - "SELECT id, name, description, created_at, updated_at FROM groups WHERE id = $1" + "SELECT id, name, description, created_at, updated_at FROM groups WHERE id = $1", ) .bind(id) .fetch_optional(&state.db) .await .map_err(|e| { - tracing::error!(error = %e); (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + tracing::error!(error = %e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) })?; - let group = group.ok_or_else(|| (StatusCode::NOT_FOUND, Json(json!({ "error": { "code": "not_found", "message": "Group not found" } }))))?; + let group = group.ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })), + ) + })?; // Fetch member counts - let host_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM host_groups WHERE group_id = $1") - .bind(id).fetch_one(&state.db).await.unwrap_or(0); - let user_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM user_groups WHERE group_id = $1") - .bind(id).fetch_one(&state.db).await.unwrap_or(0); + let host_count: i64 = + sqlx::query_scalar("SELECT COUNT(*) FROM host_groups WHERE group_id = $1") + .bind(id) + .fetch_one(&state.db) + .await + .unwrap_or(0); + let user_count: i64 = + sqlx::query_scalar("SELECT COUNT(*) FROM user_groups WHERE group_id = $1") + .bind(id) + .fetch_one(&state.db) + .await + .unwrap_or(0); - Ok(Json(json!({ "group": group, "host_count": host_count, "user_count": user_count }))) + Ok(Json( + json!({ "group": group, "host_count": host_count, "user_count": user_count }), + )) } async fn update_group( @@ -108,7 +155,10 @@ async fn update_group( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } let rows = sqlx::query( @@ -123,7 +173,10 @@ async fn update_group( .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })))); + return Err(( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })), + )); } Ok(Json(json!({ "message": "Group updated" }))) @@ -135,20 +188,43 @@ async fn delete_group( Path(id): Path, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } let rows = sqlx::query("DELETE FROM groups WHERE id = $1") - .bind(id).execute(&state.db).await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))? + .bind(id) + .execute(&state.db) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })))); + return Err(( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })), + )); } - log_event(&state.db, AuditAction::GroupDeleted, Some(auth.user_id), Some(&auth.username), - Some("group"), Some(&id.to_string()), json!({}), None, None).await; + log_event( + &state.db, + AuditAction::GroupDeleted, + Some(auth.user_id), + Some(&auth.username), + Some("group"), + Some(&id.to_string()), + json!({}), + None, + None, + ) + .await; Ok(Json(json!({ "message": "Group deleted" }))) } @@ -159,16 +235,38 @@ async fn add_user_to_group( Path((id, user_id)): Path<(Uuid, Uuid)>, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } - sqlx::query("INSERT INTO user_groups (user_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING") - .bind(user_id).bind(id) - .execute(&state.db).await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?; + sqlx::query( + "INSERT INTO user_groups (user_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING", + ) + .bind(user_id) + .bind(id) + .execute(&state.db) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })?; - log_event(&state.db, AuditAction::GroupMembershipChanged, Some(auth.user_id), Some(&auth.username), - Some("user_group"), Some(&id.to_string()), json!({ "user_id": user_id, "action": "added" }), None, None).await; + log_event( + &state.db, + AuditAction::GroupMembershipChanged, + Some(auth.user_id), + Some(&auth.username), + Some("user_group"), + Some(&id.to_string()), + json!({ "user_id": user_id, "action": "added" }), + None, + None, + ) + .await; Ok(Json(json!({ "message": "User added to group" }))) } @@ -179,16 +277,36 @@ async fn remove_user_from_group( Path((id, user_id)): Path<(Uuid, Uuid)>, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } sqlx::query("DELETE FROM user_groups WHERE user_id = $1 AND group_id = $2") - .bind(user_id).bind(id) - .execute(&state.db).await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?; + .bind(user_id) + .bind(id) + .execute(&state.db) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })?; - log_event(&state.db, AuditAction::GroupMembershipChanged, Some(auth.user_id), Some(&auth.username), - Some("user_group"), Some(&id.to_string()), json!({ "user_id": user_id, "action": "removed" }), None, None).await; + log_event( + &state.db, + AuditAction::GroupMembershipChanged, + Some(auth.user_id), + Some(&auth.username), + Some("user_group"), + Some(&id.to_string()), + json!({ "user_id": user_id, "action": "removed" }), + None, + None, + ) + .await; Ok(Json(json!({ "message": "User removed from group" }))) } diff --git a/crates/pm-web/src/routes/hosts.rs b/crates/pm-web/src/routes/hosts.rs index c1b881a..2f77a70 100644 --- a/crates/pm-web/src/routes/hosts.rs +++ b/crates/pm-web/src/routes/hosts.rs @@ -16,13 +16,11 @@ use axum::{ routing::{delete, get, post}, Router, }; +use pm_auth::rbac::AuthUser; use pm_core::{ audit::{log_event, AuditAction}, - models::{ - CreateHostRequest, HostSummary, Group, - }, + models::{CreateHostRequest, Group, HostSummary}, }; -use pm_auth::rbac::AuthUser; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use uuid::Uuid; @@ -88,12 +86,11 @@ async fn operator_can_access_host( } // Ungrouped hosts are accessible to any operator - let ungrouped: bool = sqlx::query_scalar( - "SELECT NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = $1)", - ) - .bind(host_id) - .fetch_one(pool) - .await?; + let ungrouped: bool = + sqlx::query_scalar("SELECT NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = $1)") + .bind(host_id) + .fetch_one(pool) + .await?; Ok(ungrouped) } @@ -162,7 +159,12 @@ async fn list_hosts( .await .unwrap_or(0); - Ok(Json(HostListResponse { hosts, total, limit, offset })) + Ok(Json(HostListResponse { + hosts, + total, + limit, + offset, + })) } // ── POST /api/v1/hosts ──────────────────────────────────────────────────────── @@ -244,7 +246,8 @@ async fn register_host( json!({ "fqdn": req.fqdn, "ip": ip_address }), None, None, - ).await; + ) + .await; tracing::info!(host_id = %host_id, fqdn = %req.fqdn, "Host registered"); Ok(Json(json!({ "id": host_id, "message": "Host registered" }))) @@ -291,10 +294,12 @@ async fn get_host( ) })?; - host.map(Json).ok_or_else(|| ( - StatusCode::NOT_FOUND, - Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })), - )) + host.map(Json).ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })), + ) + }) } // ── DELETE /api/v1/hosts/:id ────────────────────────────────────────────────── @@ -347,7 +352,8 @@ async fn remove_host( json!({ "fqdn": fqdn }), None, None, - ).await; + ) + .await; tracing::info!(host_id = %id, "Host removed"); Ok(Json(json!({ "message": "Host removed" }))) @@ -362,10 +368,13 @@ async fn list_host_groups( ) -> Result>, (StatusCode, Json)> { if !auth.role.is_admin() { let can_access = operator_can_access_host(&state.db, auth.user_id, id) - .await.unwrap_or(false); + .await + .unwrap_or(false); if !can_access { - return Err((StatusCode::FORBIDDEN, - Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })), + )); } } @@ -381,8 +390,10 @@ async fn list_host_groups( .await .map_err(|e| { tracing::error!(error = %e, "Failed to list host groups"); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) })?; Ok(Json(groups)) @@ -391,7 +402,9 @@ async fn list_host_groups( // ── POST /api/v1/hosts/:id/groups ───────────────────────────────────────────── #[derive(Debug, Deserialize)] -struct AddToGroupRequest { group_id: Uuid } +struct AddToGroupRequest { + group_id: Uuid, +} async fn add_host_to_group( State(state): State, @@ -400,8 +413,10 @@ async fn add_host_to_group( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, - Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } sqlx::query( @@ -413,13 +428,24 @@ async fn add_host_to_group( .await .map_err(|e| { tracing::error!(error = %e, "Failed to add host to group"); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) })?; - log_event(&state.db, AuditAction::GroupMembershipChanged, - Some(auth.user_id), Some(&auth.username), Some("host"), Some(&id.to_string()), - json!({ "group_id": req.group_id, "action": "added" }), None, None).await; + log_event( + &state.db, + AuditAction::GroupMembershipChanged, + Some(auth.user_id), + Some(&auth.username), + Some("host"), + Some(&id.to_string()), + json!({ "group_id": req.group_id, "action": "added" }), + None, + None, + ) + .await; Ok(Json(json!({ "message": "Host added to group" }))) } @@ -432,22 +458,37 @@ async fn remove_host_from_group( Path((id, group_id)): Path<(Uuid, Uuid)>, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, - Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } sqlx::query("DELETE FROM host_groups WHERE host_id = $1 AND group_id = $2") - .bind(id).bind(group_id) - .execute(&state.db).await + .bind(id) + .bind(group_id) + .execute(&state.db) + .await .map_err(|e| { tracing::error!(error = %e, "Failed to remove host from group"); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) })?; - log_event(&state.db, AuditAction::GroupMembershipChanged, - Some(auth.user_id), Some(&auth.username), Some("host"), Some(&id.to_string()), - json!({ "group_id": group_id, "action": "removed" }), None, None).await; + log_event( + &state.db, + AuditAction::GroupMembershipChanged, + Some(auth.user_id), + Some(&auth.username), + Some("host"), + Some(&id.to_string()), + json!({ "group_id": group_id, "action": "removed" }), + None, + None, + ) + .await; Ok(Json(json!({ "message": "Host removed from group" }))) } diff --git a/crates/pm-web/src/routes/jobs.rs b/crates/pm-web/src/routes/jobs.rs index 42adfc3..5a5a875 100644 --- a/crates/pm-web/src/routes/jobs.rs +++ b/crates/pm-web/src/routes/jobs.rs @@ -149,7 +149,11 @@ async fn create_job( .await .map_err(|e| { tracing::error!(error = %e, "create_job: insert patch_jobs failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; // Insert one patch_job_hosts row per requested host. @@ -170,7 +174,11 @@ async fn create_job( error = %e, %job_id, %host_id, "create_job: insert patch_job_hosts failed" ); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; } @@ -310,7 +318,12 @@ async fn list_jobs( .unwrap_or(0) }; - Ok(Json(JobListResponse { jobs, total, limit, offset })) + Ok(Json(JobListResponse { + jobs, + total, + limit, + offset, + })) } // ── GET /api/v1/jobs/:id ───────────────────────────────────────────────────── @@ -325,11 +338,7 @@ async fn get_job( .await .unwrap_or(false); if !allowed { - return Err(err( - StatusCode::FORBIDDEN, - "forbidden", - "Access denied", - )); + return Err(err(StatusCode::FORBIDDEN, "forbidden", "Access denied")); } } @@ -350,12 +359,14 @@ async fn get_job( .await .map_err(|e| { tracing::error!(error = %e, %id, "get_job: failed to fetch job"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; - let job = job.ok_or_else(|| { - err(StatusCode::NOT_FOUND, "not_found", "Job not found") - })?; + let job = job.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?; // Fetch per-host status rows joined to the host display name. let hosts: Vec = sqlx::query_as( @@ -383,7 +394,11 @@ async fn get_job( .await .map_err(|e| { tracing::error!(error = %e, %id, "get_job: failed to fetch host rows"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; Ok(Json(json!({ "job": job, "hosts": hosts }))) @@ -397,20 +412,22 @@ async fn cancel_job( Path(id): Path, ) -> Result, (StatusCode, Json)> { // Fetch the job to verify it exists and check ownership. - let row: Option<(String, Option)> = sqlx::query_as( - "SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1", - ) - .bind(id) - .fetch_optional(&state.db) - .await - .map_err(|e| { - tracing::error!(error = %e, %id, "cancel_job: db fetch failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") - })?; + let row: Option<(String, Option)> = + sqlx::query_as("SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1") + .bind(id) + .fetch_optional(&state.db) + .await + .map_err(|e| { + tracing::error!(error = %e, %id, "cancel_job: db fetch failed"); + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) + })?; - let (status_str, creator_id) = row.ok_or_else(|| { - err(StatusCode::NOT_FOUND, "not_found", "Job not found") - })?; + let (status_str, creator_id) = + row.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?; // Only admin or the job creator may cancel. if !auth.role.is_admin() { @@ -437,16 +454,18 @@ async fn cancel_job( } // Cancel the parent job. - sqlx::query( - "UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1", - ) - .bind(id) - .execute(&state.db) - .await - .map_err(|e| { - tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") - })?; + sqlx::query("UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1") + .bind(id) + .execute(&state.db) + .await + .map_err(|e| { + tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed"); + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) + })?; // Cancel all queued/pending host rows for this job. sqlx::query( @@ -462,7 +481,11 @@ async fn cancel_job( .await .map_err(|e| { tracing::error!(error = %e, %id, "cancel_job: update patch_job_hosts failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; log_event( @@ -506,7 +529,11 @@ async fn rollback_job( .await .map_err(|e| { tracing::error!(error = %e, %id, "rollback_job: existence check failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; if !original_exists { @@ -521,7 +548,11 @@ async fn rollback_job( .await .map_err(|e| { tracing::error!(error = %e, %id, "rollback_job: host fetch failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; if host_ids.is_empty() { @@ -552,7 +583,11 @@ async fn rollback_job( .await .map_err(|e| { tracing::error!(error = %e, parent_job_id = %id, "rollback_job: insert failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; // Replicate host list into the rollback job. @@ -573,7 +608,11 @@ async fn rollback_job( error = %e, %rollback_job_id, %host_id, "rollback_job: insert patch_job_hosts failed" ); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; } diff --git a/crates/pm-web/src/routes/maintenance_windows.rs b/crates/pm-web/src/routes/maintenance_windows.rs index 90aa2fa..ce32b56 100644 --- a/crates/pm-web/src/routes/maintenance_windows.rs +++ b/crates/pm-web/src/routes/maintenance_windows.rs @@ -15,9 +15,7 @@ use axum::{ use pm_auth::rbac::AuthUser; use pm_core::{ audit::{log_event, AuditAction}, - models::{ - CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest, - }, + models::{CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest}, }; use serde_json::{json, Value}; use uuid::Uuid; @@ -56,15 +54,18 @@ async fn list_windows( Path(host_id): Path, ) -> Result, (StatusCode, Json)> { // Verify host exists. - let host_exists: bool = - sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)") - .bind(host_id) - .fetch_one(&state.db) - .await - .map_err(|e| { - tracing::error!(error = %e, %host_id, "list_windows: host existence check failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") - })?; + let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)") + .bind(host_id) + .fetch_one(&state.db) + .await + .map_err(|e| { + tracing::error!(error = %e, %host_id, "list_windows: host existence check failed"); + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) + })?; if !host_exists { return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found")); @@ -84,7 +85,11 @@ async fn list_windows( .await .map_err(|e| { tracing::error!(error = %e, %host_id, "list_windows: query failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; Ok(Json(json!({ "windows": windows }))) @@ -101,41 +106,44 @@ async fn create_window( // Validate: weekly requires recurrence_day 0-6 if req.recurrence == pm_core::models::WindowRecurrence::Weekly { match req.recurrence_day { - Some(d) if (0..=6).contains(&d) => {} + Some(d) if (0..=6).contains(&d) => {}, _ => { return Err(err( StatusCode::BAD_REQUEST, "bad_request", "Weekly recurrence requires recurrence_day 0-6 (0=Sunday)", )); - } + }, } } // Validate: monthly requires recurrence_day 1-31 if req.recurrence == pm_core::models::WindowRecurrence::Monthly { match req.recurrence_day { - Some(d) if (1..=31).contains(&d) => {} + Some(d) if (1..=31).contains(&d) => {}, _ => { return Err(err( StatusCode::BAD_REQUEST, "bad_request", "Monthly recurrence requires recurrence_day 1-31", )); - } + }, } } // Verify host exists. - let host_exists: bool = - sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)") - .bind(host_id) - .fetch_one(&state.db) - .await - .map_err(|e| { - tracing::error!(error = %e, %host_id, "create_window: host existence check failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") - })?; + let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)") + .bind(host_id) + .fetch_one(&state.db) + .await + .map_err(|e| { + tracing::error!(error = %e, %host_id, "create_window: host existence check failed"); + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) + })?; if !host_exists { return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found")); @@ -165,7 +173,11 @@ async fn create_window( .await .map_err(|e| { tracing::error!(error = %e, %host_id, "create_window: insert failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; log_event( @@ -219,44 +231,52 @@ async fn update_window( .await .map_err(|e| { tracing::error!(error = %e, %win_id, "update_window: fetch failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; let existing = existing.ok_or_else(|| { - err(StatusCode::NOT_FOUND, "not_found", "Maintenance window not found") + err( + StatusCode::NOT_FOUND, + "not_found", + "Maintenance window not found", + ) })?; // Apply partial updates using existing values as defaults. - let new_label = req.label.unwrap_or(existing.label); + let new_label = req.label.unwrap_or(existing.label); let new_recurrence = req.recurrence.unwrap_or(existing.recurrence); - let new_start_at = req.start_at.unwrap_or(existing.start_at); - let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes); - let new_rec_day = req.recurrence_day.or(existing.recurrence_day); - let new_enabled = req.enabled.unwrap_or(existing.enabled); + let new_start_at = req.start_at.unwrap_or(existing.start_at); + let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes); + let new_rec_day = req.recurrence_day.or(existing.recurrence_day); + let new_enabled = req.enabled.unwrap_or(existing.enabled); // Validate recurrence_day for the final recurrence type. if new_recurrence == pm_core::models::WindowRecurrence::Weekly { match new_rec_day { - Some(d) if (0..=6).contains(&d) => {} + Some(d) if (0..=6).contains(&d) => {}, _ => { return Err(err( StatusCode::BAD_REQUEST, "bad_request", "Weekly recurrence requires recurrence_day 0-6", )); - } + }, } } if new_recurrence == pm_core::models::WindowRecurrence::Monthly { match new_rec_day { - Some(d) if (1..=31).contains(&d) => {} + Some(d) if (1..=31).contains(&d) => {}, _ => { return Err(err( StatusCode::BAD_REQUEST, "bad_request", "Monthly recurrence requires recurrence_day 1-31", )); - } + }, } } @@ -287,7 +307,11 @@ async fn update_window( .await .map_err(|e| { tracing::error!(error = %e, %win_id, "update_window: update failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) })?; log_event( @@ -320,17 +344,19 @@ async fn delete_window( auth: AuthUser, Path((host_id, win_id)): Path<(Uuid, Uuid)>, ) -> Result, (StatusCode, Json)> { - let result = sqlx::query( - "DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2", - ) - .bind(win_id) - .bind(host_id) - .execute(&state.db) - .await - .map_err(|e| { - tracing::error!(error = %e, %win_id, "delete_window: delete failed"); - err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error") - })?; + let result = sqlx::query("DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2") + .bind(win_id) + .bind(host_id) + .execute(&state.db) + .await + .map_err(|e| { + tracing::error!(error = %e, %win_id, "delete_window: delete failed"); + err( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_error", + "Database error", + ) + })?; if result.rows_affected() == 0 { return Err(err( diff --git a/crates/pm-web/src/routes/mod.rs b/crates/pm-web/src/routes/mod.rs index 923e9de..93a091c 100644 --- a/crates/pm-web/src/routes/mod.rs +++ b/crates/pm-web/src/routes/mod.rs @@ -1,15 +1,15 @@ //! Route modules for the pm-web API. pub mod auth; +pub mod azure_sso; pub mod ca; pub mod discovery; pub mod groups; pub mod hosts; -pub mod maintenance_windows; pub mod jobs; +pub mod maintenance_windows; +pub mod settings; pub mod status; pub mod users; -pub mod settings; -pub mod azure_sso; pub mod ws; pub mod reports; diff --git a/crates/pm-web/src/routes/reports.rs b/crates/pm-web/src/routes/reports.rs index 0ab5ea3..5efee2f 100644 --- a/crates/pm-web/src/routes/reports.rs +++ b/crates/pm-web/src/routes/reports.rs @@ -28,10 +28,10 @@ struct ReportQuery { pub fn router() -> Router { Router::new() - .route("/compliance", get(compliance_report)) + .route("/compliance", get(compliance_report)) .route("/patch-history", get(patch_history_report)) .route("/vulnerability", get(vulnerability_report)) - .route("/audit", get(audit_report)) + .route("/audit", get(audit_report)) } // --------------------------------------------------------------------------- @@ -58,21 +58,22 @@ async fn run_report( match result { Ok(bytes) => { let mut headers = HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - HeaderValue::from_static(ct), - ); + headers.insert(header::CONTENT_TYPE, HeaderValue::from_static(ct)); headers.insert( header::CONTENT_DISPOSITION, HeaderValue::from_str(&disposition) .unwrap_or_else(|_| HeaderValue::from_static("attachment")), ); (headers, Bytes::from(bytes)).into_response() - } + }, Err(e) => { tracing::error!(error = %e, "report generation failed"); - (StatusCode::INTERNAL_SERVER_ERROR, format!("Report error: {}", e)).into_response() - } + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Report error: {}", e), + ) + .into_response() + }, } } @@ -92,10 +93,13 @@ async fn compliance_report( }; let use_pdf = matches!(q.format.as_deref(), Some("pdf")); run_report( - state.db, params, use_pdf, + state.db, + params, + use_pdf, "compliance-report.csv", "compliance-report.pdf", - ).await + ) + .await } async fn patch_history_report( @@ -110,10 +114,13 @@ async fn patch_history_report( }; let use_pdf = matches!(q.format.as_deref(), Some("pdf")); run_report( - state.db, params, use_pdf, + state.db, + params, + use_pdf, "patch-history-report.csv", "patch-history-report.pdf", - ).await + ) + .await } async fn vulnerability_report( @@ -128,16 +135,16 @@ async fn vulnerability_report( }; let use_pdf = matches!(q.format.as_deref(), Some("pdf")); run_report( - state.db, params, use_pdf, + state.db, + params, + use_pdf, "vulnerability-report.csv", "vulnerability-report.pdf", - ).await + ) + .await } -async fn audit_report( - State(state): State, - Query(q): Query, -) -> Response { +async fn audit_report(State(state): State, Query(q): Query) -> Response { let params = ReportParams { report_type: ReportType::Audit, from: q.from, @@ -146,8 +153,11 @@ async fn audit_report( }; let use_pdf = matches!(q.format.as_deref(), Some("pdf")); run_report( - state.db, params, use_pdf, + state.db, + params, + use_pdf, "audit-report.csv", "audit-report.pdf", - ).await + ) + .await } diff --git a/crates/pm-web/src/routes/settings.rs b/crates/pm-web/src/routes/settings.rs index f669d7c..054d718 100644 --- a/crates/pm-web/src/routes/settings.rs +++ b/crates/pm-web/src/routes/settings.rs @@ -20,8 +20,8 @@ use lettre::{ transport::smtp::authentication::Credentials, AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor, }; -use pm_core::audit::{log_event, verify_integrity, AuditAction}; use pm_auth::rbac::AuthUser; +use pm_core::audit::{log_event, verify_integrity, AuditAction}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashMap; @@ -132,7 +132,10 @@ pub fn router() -> Router { .route("/", get(get_settings).put(update_settings)) .route("/azure-sso/test", post(test_azure_sso)) .route("/smtp/test", post(test_smtp)) - .route("/ip-whitelist", get(get_ip_whitelist).put(update_ip_whitelist)) + .route( + "/ip-whitelist", + get(get_ip_whitelist).put(update_ip_whitelist), + ) .route("/audit-integrity", post(audit_integrity)) } @@ -155,26 +158,28 @@ fn admin_only(auth: &AuthUser) -> Result<(), (StatusCode, Json)> { async fn load_system_config( pool: &sqlx::PgPool, ) -> Result, (StatusCode, Json)> { - let rows: Vec<(String, String)> = sqlx::query_as( - "SELECT key, value FROM system_config", - ) - .fetch_all(pool) - .await - .map_err(|e| { - tracing::error!(error = %e, "Failed to load system_config"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), - ) - })?; + let rows: Vec<(String, String)> = sqlx::query_as("SELECT key, value FROM system_config") + .fetch_all(pool) + .await + .map_err(|e| { + tracing::error!(error = %e, "Failed to load system_config"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) + })?; Ok(rows.into_iter().collect()) } -fn build_settings_response(cfg: &HashMap, azure: AzureSsoConfig) -> SettingsResponse { +fn build_settings_response( + cfg: &HashMap, + azure: AzureSsoConfig, +) -> SettingsResponse { let get = |key: &str| -> String { cfg.get(key).cloned().unwrap_or_default() }; - let recipients: Vec = serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default(); + let recipients: Vec = + serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default(); SettingsResponse { azure_sso: azure, @@ -517,7 +522,7 @@ async fn test_azure_sso( "success": false, "message": "Azure SSO is not configured" }))); - } + }, }; if tenant_id.is_empty() { @@ -560,7 +565,7 @@ async fn test_azure_sso( "issuer": issuer }))) } - } + }, Ok(resp) => Ok(Json(json!({ "success": false, "message": format!("Failed to reach Azure AD: HTTP {}", resp.status()) @@ -593,11 +598,17 @@ async fn test_smtp( } let host = cfg.get("smtp_host").cloned().unwrap_or_default(); - let port: u16 = cfg.get("smtp_port").and_then(|v| v.parse().ok()).unwrap_or(587); + let port: u16 = cfg + .get("smtp_port") + .and_then(|v| v.parse().ok()) + .unwrap_or(587); let username = cfg.get("smtp_username").cloned().unwrap_or_default(); let password = cfg.get("smtp_password").cloned().unwrap_or_default(); let from_addr = cfg.get("smtp_from").cloned().unwrap_or_default(); - let tls_mode = cfg.get("smtp_tls_mode").cloned().unwrap_or_else(|| "starttls".to_string()); + let tls_mode = cfg + .get("smtp_tls_mode") + .cloned() + .unwrap_or_else(|| "starttls".to_string()); if host.is_empty() || from_addr.is_empty() { return Ok(Json(json!({ @@ -628,7 +639,9 @@ async fn send_smtp_test( from_addr: &str, tls_mode: &str, ) -> Result<(), String> { - let from_mailbox: Mailbox = from_addr.parse().map_err(|e| format!("Invalid from address: {}", e))?; + let from_mailbox: Mailbox = from_addr + .parse() + .map_err(|e| format!("Invalid from address: {}", e))?; let email = Message::builder() .from(from_mailbox.clone()) @@ -644,33 +657,39 @@ async fn send_smtp_test( .map_err(|e| format!("TLS relay error: {}", e))?; builder = builder.port(port); if !username.is_empty() { - builder = builder.credentials(Credentials::new(username.to_string(), password.to_string())); + builder = builder + .credentials(Credentials::new(username.to_string(), password.to_string())); } let transport = builder.build(); transport.send(email).await - } + }, "starttls" => { let mut builder = AsyncSmtpTransport::::starttls_relay(host) .map_err(|e| format!("STARTTLS relay error: {}", e))?; builder = builder.port(port); if !username.is_empty() { - builder = builder.credentials(Credentials::new(username.to_string(), password.to_string())); + builder = builder + .credentials(Credentials::new(username.to_string(), password.to_string())); } let transport = builder.build(); transport.send(email).await - } + }, _ => { // "none" — plaintext / no TLS - let mut builder = AsyncSmtpTransport::::builder_dangerous(host).port(port); + let mut builder = + AsyncSmtpTransport::::builder_dangerous(host).port(port); if !username.is_empty() { - builder = builder.credentials(Credentials::new(username.to_string(), password.to_string())); + builder = builder + .credentials(Credentials::new(username.to_string(), password.to_string())); } let transport = builder.build(); transport.send(email).await - } + }, }; - result.map(|_| ()).map_err(|e| format!("SMTP send error: {}", e)) + result + .map(|_| ()) + .map_err(|e| format!("SMTP send error: {}", e)) } // ============================================================ @@ -713,12 +732,12 @@ async fn update_ip_whitelist( // Validate each entry for entry in &req.entries { - if entry.parse::().is_err() - && entry.parse::().is_err() - { + if entry.parse::().is_err() && entry.parse::().is_err() { return Err(( StatusCode::BAD_REQUEST, - Json(json!({ "error": { "code": "bad_request", "message": format!("Invalid CIDR or IP: {}", entry) } })), + Json( + json!({ "error": { "code": "bad_request", "message": format!("Invalid CIDR or IP: {}", entry) } }), + ), )); } } diff --git a/crates/pm-web/src/routes/status.rs b/crates/pm-web/src/routes/status.rs index a9eac68..bd5db21 100644 --- a/crates/pm-web/src/routes/status.rs +++ b/crates/pm-web/src/routes/status.rs @@ -2,13 +2,7 @@ //! //! GET /api/v1/status/fleet — aggregate health and patch summary across all hosts. -use axum::{ - extract::State, - http::StatusCode, - response::Json, - routing::get, - Router, -}; +use axum::{extract::State, http::StatusCode, response::Json, routing::get, Router}; use serde::Serialize; use serde_json::{json, Value}; diff --git a/crates/pm-web/src/routes/users.rs b/crates/pm-web/src/routes/users.rs index be1e47c..cca15f2 100644 --- a/crates/pm-web/src/routes/users.rs +++ b/crates/pm-web/src/routes/users.rs @@ -15,11 +15,11 @@ use axum::{ routing::{delete, get, post, put}, Router, }; +use pm_auth::{hash_password, rbac::AuthUser, session::force_logout}; use pm_core::{ audit::{log_event, AuditAction}, - models::{User, CreateUserRequest, UpdateUserRequest}, + models::{CreateUserRequest, UpdateUserRequest, User}, }; -use pm_auth::{hash_password, rbac::AuthUser, session::force_logout}; use serde_json::{json, Value}; use uuid::Uuid; @@ -38,7 +38,10 @@ async fn list_users( auth: AuthUser, ) -> Result>, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } sqlx::query_as::<_, User>( @@ -52,7 +55,10 @@ async fn list_users( .map(Json) .map_err(|e| { tracing::error!(error = %e); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) }) } @@ -62,14 +68,24 @@ async fn create_user( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } let hash = hash_password(&req.password).map_err(|e| { - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) })?; - let role = if req.role == "admin" { "admin" } else { "operator" }; + let role = if req.role == "admin" { + "admin" + } else { + "operator" + }; let id: Uuid = sqlx::query_scalar( r#"INSERT INTO users (username, display_name, email, role, auth_provider, password_hash) @@ -84,12 +100,29 @@ async fn create_user( .fetch_one(&state.db) .await .map_err(|e| { - let msg = if e.to_string().contains("unique") { "Username or email already exists".to_string() } else { "Database error".to_string() }; - (StatusCode::CONFLICT, Json(json!({ "error": { "code": "conflict", "message": msg } }))) + let msg = if e.to_string().contains("unique") { + "Username or email already exists".to_string() + } else { + "Database error".to_string() + }; + ( + StatusCode::CONFLICT, + Json(json!({ "error": { "code": "conflict", "message": msg } })), + ) })?; - log_event(&state.db, AuditAction::UserCreated, Some(auth.user_id), Some(&auth.username), - Some("user"), Some(&id.to_string()), json!({ "username": req.username }), None, None).await; + log_event( + &state.db, + AuditAction::UserCreated, + Some(auth.user_id), + Some(&auth.username), + Some("user"), + Some(&id.to_string()), + json!({ "username": req.username }), + None, + None, + ) + .await; Ok(Json(json!({ "id": id, "message": "User created" }))) } @@ -108,12 +141,18 @@ async fn get_user( ) -> Result, (StatusCode, Json)> { // Users can see themselves; admin can see anyone if !auth.role.is_admin() && auth.user_id != id { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })), + )); } fetch_user(&state.db, id).await } -async fn fetch_user(pool: &sqlx::PgPool, id: Uuid) -> Result, (StatusCode, Json)> { +async fn fetch_user( + pool: &sqlx::PgPool, + id: Uuid, +) -> Result, (StatusCode, Json)> { let user: Option = sqlx::query_as( r#"SELECT id, username, display_name, email, role, auth_provider, mfa_enabled, is_active, force_password_reset, last_login_at, @@ -125,10 +164,18 @@ async fn fetch_user(pool: &sqlx::PgPool, id: Uuid) -> Result, (Status .await .map_err(|e| { tracing::error!(error = %e); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": "Database error" } }))) + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })), + ) })?; - user.map(Json).ok_or_else(|| (StatusCode::NOT_FOUND, Json(json!({ "error": { "code": "not_found", "message": "User not found" } })))) + user.map(Json).ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "User not found" } })), + ) + }) } async fn update_user( @@ -138,14 +185,25 @@ async fn update_user( Json(req): Json, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() && auth.user_id != id { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })), + )); } // Only admins can change role or active status if (req.role.is_some() || req.is_active.is_some()) && !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required to change role or status" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json( + json!({ "error": { "code": "forbidden", "message": "Admin role required to change role or status" } }), + ), + )); } - let role_str = req.role.as_deref().map(|r| if r == "admin" { "admin" } else { "operator" }); + let role_str = req + .role + .as_deref() + .map(|r| if r == "admin" { "admin" } else { "operator" }); let rows = sqlx::query( r#"UPDATE users SET @@ -163,15 +221,33 @@ async fn update_user( .bind(id) .execute(&state.db) .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))? + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, Json(json!({ "error": { "code": "not_found", "message": "User not found" } })))); + return Err(( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "User not found" } })), + )); } - log_event(&state.db, AuditAction::UserUpdated, Some(auth.user_id), Some(&auth.username), - Some("user"), Some(&id.to_string()), json!({}), None, None).await; + log_event( + &state.db, + AuditAction::UserUpdated, + Some(auth.user_id), + Some(&auth.username), + Some("user"), + Some(&id.to_string()), + json!({}), + None, + None, + ) + .await; Ok(Json(json!({ "message": "User updated" }))) } @@ -182,23 +258,51 @@ async fn delete_user( Path(id): Path, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } if auth.user_id == id { - return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": { "code": "bad_request", "message": "Cannot delete your own account" } })))); + return Err(( + StatusCode::BAD_REQUEST, + Json( + json!({ "error": { "code": "bad_request", "message": "Cannot delete your own account" } }), + ), + )); } let rows = sqlx::query("DELETE FROM users WHERE id = $1") - .bind(id).execute(&state.db).await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))? + .bind(id) + .execute(&state.db) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })? .rows_affected(); if rows == 0 { - return Err((StatusCode::NOT_FOUND, Json(json!({ "error": { "code": "not_found", "message": "User not found" } })))); + return Err(( + StatusCode::NOT_FOUND, + Json(json!({ "error": { "code": "not_found", "message": "User not found" } })), + )); } - log_event(&state.db, AuditAction::UserDeleted, Some(auth.user_id), Some(&auth.username), - Some("user"), Some(&id.to_string()), json!({}), None, None).await; + log_event( + &state.db, + AuditAction::UserDeleted, + Some(auth.user_id), + Some(&auth.username), + Some("user"), + Some(&id.to_string()), + json!({}), + None, + None, + ) + .await; Ok(Json(json!({ "message": "User deleted" }))) } @@ -209,11 +313,20 @@ async fn revoke_user_sessions( Path(id): Path, ) -> Result, (StatusCode, Json)> { if !auth.role.is_admin() { - return Err((StatusCode::FORBIDDEN, Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })))); + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })), + )); } - let count = force_logout(&state.db, id).await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?; + let count = force_logout(&state.db, id).await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })), + ) + })?; - Ok(Json(json!({ "message": "Sessions revoked", "count": count }))) + Ok(Json( + json!({ "message": "Sessions revoked", "count": count }), + )) } diff --git a/crates/pm-web/src/routes/ws.rs b/crates/pm-web/src/routes/ws.rs index 9f3128b..633f763 100644 --- a/crates/pm-web/src/routes/ws.rs +++ b/crates/pm-web/src/routes/ws.rs @@ -4,8 +4,8 @@ //! GET /api/v1/ws/jobs — browser WebSocket endpoint (ticket-authenticated) use axum::{ - extract::{Query, State, WebSocketUpgrade}, extract::ws::{Message, WebSocket}, + extract::{Query, State, WebSocketUpgrade}, http::StatusCode, response::{Json, Response}, routing::{get, post}, @@ -59,7 +59,6 @@ fn err( // ── POST /api/v1/ws/ticket ──────────────────────────────────────────────────── - /// Issue a single-use WebSocket authentication ticket (60 s expiry). pub async fn create_ticket_handler( State(state): State, @@ -109,7 +108,7 @@ pub async fn ws_handler( "invalid_ticket", "WebSocket ticket not found or already used", )); - } + }, Some(t) => { if t.expires_at < Utc::now() { drop(t); @@ -121,7 +120,7 @@ pub async fn ws_handler( )); } t.clone() - } + }, } }; // Single-use: remove immediately after validation. @@ -140,11 +139,7 @@ pub async fn ws_handler( // ── WebSocket handler ───────────────────────────────────────────────────────── /// Drive the browser WebSocket: LISTEN on `job_update` and forward payloads. -async fn handle_browser_ws( - mut socket: WebSocket, - db: sqlx::PgPool, - ticket: WsTicket, -) { +async fn handle_browser_ws(mut socket: WebSocket, db: sqlx::PgPool, ticket: WsTicket) { // Acquire a dedicated PG listener connection. let mut listener = match PgListener::connect_with(&db).await { Ok(l) => l, @@ -156,7 +151,7 @@ async fn handle_browser_ws( )) .await; return; - } + }, }; if let Err(e) = listener.listen("job_update").await { diff --git a/crates/pm-worker/src/agent_loader.rs b/crates/pm-worker/src/agent_loader.rs index 812a723..d6b6cbf 100644 --- a/crates/pm-worker/src/agent_loader.rs +++ b/crates/pm-worker/src/agent_loader.rs @@ -34,12 +34,12 @@ pub fn load_agent_certs(security: &SecurityConfig) -> anyhow::Result })?; let ca_cert = std::fs::read(&security.ca_cert_path).map_err(|e| { - anyhow::anyhow!( - "Failed to read CA cert '{}': {}", - security.ca_cert_path, - e - ) + anyhow::anyhow!("Failed to read CA cert '{}': {}", security.ca_cert_path, e) })?; - Ok(AgentCerts { client_cert, client_key, ca_cert }) + Ok(AgentCerts { + client_cert, + client_key, + ca_cert, + }) } diff --git a/crates/pm-worker/src/email.rs b/crates/pm-worker/src/email.rs index c53ec03..115991c 100644 --- a/crates/pm-worker/src/email.rs +++ b/crates/pm-worker/src/email.rs @@ -80,7 +80,8 @@ async fn load_notification_settings(pool: &PgPool) -> NotificationSettings { .unwrap_or_default() }; - let recipients: Vec = serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default(); + let recipients: Vec = + serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default(); NotificationSettings { email_enabled: get("notification_email_enabled") == "true", @@ -90,9 +91,7 @@ async fn load_notification_settings(pool: &PgPool) -> NotificationSettings { } /// Build an async SMTP transport from settings. -fn build_transport( - settings: &SmtpSettings, -) -> Result, String> { +fn build_transport(settings: &SmtpSettings) -> Result, String> { match settings.tls_mode.as_str() { "tls" => { let mut builder = AsyncSmtpTransport::::relay(&settings.host) @@ -105,7 +104,7 @@ fn build_transport( )); } Ok(builder.build()) - } + }, "starttls" => { let mut builder = AsyncSmtpTransport::::starttls_relay(&settings.host) .map_err(|e| format!("STARTTLS relay error: {}", e))?; @@ -117,11 +116,12 @@ fn build_transport( )); } Ok(builder.build()) - } + }, _ => { // "none" — plaintext / no TLS - let mut builder = AsyncSmtpTransport::::builder_dangerous(&settings.host) - .port(settings.port); + let mut builder = + AsyncSmtpTransport::::builder_dangerous(&settings.host) + .port(settings.port); if !settings.username.is_empty() { builder = builder.credentials(Credentials::new( settings.username.clone(), @@ -129,21 +129,17 @@ fn build_transport( )); } Ok(builder.build()) - } + }, } } /// Send an email notification. Returns true if the email was sent successfully. -async fn send_email( - pool: &PgPool, - subject: &str, - body: &str, -) -> bool { +async fn send_email(pool: &PgPool, subject: &str, body: &str) -> bool { let smtp = match load_smtp_settings(pool).await { s if !s.enabled => { tracing::debug!("SMTP not enabled, skipping email notification"); return false; - } + }, s => s, }; @@ -169,7 +165,7 @@ async fn send_email( Err(e) => { tracing::error!(error = %e, "Invalid from address for email notification"); return false; - } + }, }; let mut builder = Message::builder() @@ -184,7 +180,7 @@ async fn send_email( Err(e) => { tracing::error!(error = %e, recipient = %recipient, "Invalid recipient address"); continue; - } + }, }; builder = builder.to(mailbox); } @@ -194,7 +190,7 @@ async fn send_email( Err(e) => { tracing::error!(error = %e, "Failed to build email message"); return false; - } + }, }; let transport = match build_transport(&smtp) { @@ -202,18 +198,18 @@ async fn send_email( Err(e) => { tracing::error!(error = %e, "Failed to build SMTP transport"); return false; - } + }, }; match transport.send(email).await { Ok(_) => { tracing::info!(subject, "Email notification sent successfully"); true - } + }, Err(e) => { tracing::error!(error = %e, subject, "Failed to send email notification"); false - } + }, } } @@ -300,7 +296,10 @@ pub async fn send_maintenance_window_reminder_email( window_label: &str, start_at: &str, ) { - let subject = format!("[Patch Manager] Upcoming Maintenance Window: {}", window_label); + let subject = format!( + "[Patch Manager] Upcoming Maintenance Window: {}", + window_label + ); let body = format!( "Maintenance window reminder:\n\ Host: {host_fqdn}\n\ diff --git a/crates/pm-worker/src/health_poller.rs b/crates/pm-worker/src/health_poller.rs index 46cd76c..18d12cf 100644 --- a/crates/pm-worker/src/health_poller.rs +++ b/crates/pm-worker/src/health_poller.rs @@ -7,15 +7,9 @@ use std::sync::Arc; use pm_agent_client::{AgentClient, AgentClientError}; -use pm_core::{ - config::AppConfig, - models::HostHealthStatus, -}; +use pm_core::{config::AppConfig, models::HostHealthStatus}; use sqlx::{FromRow, PgPool}; -use tokio::{ - sync::Semaphore, - time, -}; +use tokio::{sync::Semaphore, time}; use uuid::Uuid; use crate::agent_loader::load_agent_certs; @@ -37,10 +31,7 @@ pub async fn run_health_poller(pool: PgPool, config: Arc) { let interval_secs = config.worker.health_poll_interval_secs; let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs)); - tracing::info!( - interval_secs, - "Health poller started" - ); + tracing::info!(interval_secs, "Health poller started"); loop { ticker.tick().await; @@ -51,7 +42,7 @@ pub async fn run_health_poller(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "Health poller: failed to load agent certs — skipping cycle"); continue; - } + }, }; let client_cert = Arc::new(certs.client_cert); @@ -69,7 +60,7 @@ pub async fn run_health_poller(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "Health poller: failed to fetch hosts"); continue; - } + }, }; if hosts.is_empty() { @@ -107,7 +98,7 @@ pub async fn run_health_poller(pool: PgPool, config: Arc) { Ok(HostHealthStatus::Healthy) => healthy += 1, Ok(HostHealthStatus::Degraded) => degraded += 1, Ok(HostHealthStatus::Unreachable) => unreachable += 1, - Ok(_) => {} + Ok(_) => {}, Err(e) => tracing::error!(error = %e, "Health poller task panicked"), } } @@ -144,25 +135,37 @@ async fn poll_host_health( error = %e, "Health poller: failed to build AgentClient" ); - (HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default())) - } + ( + HostHealthStatus::Unreachable, + serde_json::Value::Object(Default::default()), + ) + }, Ok(client) => match client.health().await { Ok(data) => { let payload = serde_json::to_value(&data).unwrap_or_default(); (HostHealthStatus::Healthy, payload) - } + }, Err(AgentClientError::Timeout) => { tracing::warn!(host_id = %host.id, "Health poller: agent timed out"); - (HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default())) - } + ( + HostHealthStatus::Unreachable, + serde_json::Value::Object(Default::default()), + ) + }, Err(AgentClientError::Connect(_)) => { tracing::warn!(host_id = %host.id, "Health poller: agent connection refused"); - (HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default())) - } + ( + HostHealthStatus::Unreachable, + serde_json::Value::Object(Default::default()), + ) + }, Err(e) => { tracing::warn!(host_id = %host.id, error = %e, "Health poller: agent error"); - (HostHealthStatus::Degraded, serde_json::Value::Object(Default::default())) - } + ( + HostHealthStatus::Degraded, + serde_json::Value::Object(Default::default()), + ) + }, }, }; diff --git a/crates/pm-worker/src/job_executor.rs b/crates/pm-worker/src/job_executor.rs index 1e7912f..e25b58f 100644 --- a/crates/pm-worker/src/job_executor.rs +++ b/crates/pm-worker/src/job_executor.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use chrono::{Duration as ChronoDuration, Utc}; -use pm_agent_client::{AgentClient, types::ApplyPatchesRequest}; +use pm_agent_client::{types::ApplyPatchesRequest, AgentClient}; use pm_core::config::AppConfig; use sqlx::{FromRow, PgPool}; use tokio::{sync::Semaphore, time}; @@ -71,13 +71,13 @@ struct RetryRow { #[derive(Debug, FromRow)] struct StatusCounts { - running_count: i64, - pending_count: i64, - queued_count: i64, + running_count: i64, + pending_count: i64, + queued_count: i64, succeeded_count: i64, - failed_count: i64, + failed_count: i64, cancelled_count: i64, - total_count: i64, + total_count: i64, } // ───────────────────────────────────────────────────────────────────────────── @@ -125,12 +125,8 @@ async fn run_notify_listener(pool: PgPool, config: Arc) { /// Inner NOTIFY loop — returns `Err` only on a fatal connection error so the /// outer loop can reconnect. -async fn notify_listen_loop( - pool: &PgPool, - config: &Arc, -) -> anyhow::Result<()> { - let mut listener = - sqlx::postgres::PgListener::connect(&config.database.url).await?; +async fn notify_listen_loop(pool: &PgPool, config: &Arc) -> anyhow::Result<()> { + let mut listener = sqlx::postgres::PgListener::connect(&config.database.url).await?; listener.listen("job_enqueued").await?; tracing::debug!("Job executor NOTIFY listener connected"); @@ -148,7 +144,7 @@ async fn notify_listen_loop( "Job executor: invalid UUID in job_enqueued payload" ); continue; - } + }, }; let (p, c) = (pool.clone(), config.clone()); @@ -301,7 +297,7 @@ pub async fn process_job(pool: PgPool, config: Arc, job_id: Uuid) { Err(e) => { tracing::error!(%job_id, error = %e, "process_job: failed to fetch queued hosts"); return; - } + }, }; if hosts.is_empty() { @@ -317,11 +313,11 @@ pub async fn process_job(pool: PgPool, config: Arc, job_id: Uuid) { Err(e) => { tracing::error!(%job_id, error = %e, "process_job: semaphore closed"); break; - } + }, }; let (p, c) = (pool.clone(), config.clone()); - let pjh_id = host.id; + let pjh_id = host.id; let host_id = host.host_id; tokio::spawn(async move { @@ -338,11 +334,11 @@ pub async fn process_job(pool: PgPool, config: Arc, job_id: Uuid) { /// Connect to a single host agent, submit the patch job, and record the /// agent-assigned async job ID for later polling. async fn execute_host_job( - pool: PgPool, - config: Arc, - job_id: Uuid, + pool: PgPool, + config: Arc, + job_id: Uuid, host_id: Uuid, - pjh_id: Uuid, + pjh_id: Uuid, ) { tracing::info!(%job_id, %host_id, %pjh_id, "execute_host_job: starting"); @@ -364,34 +360,33 @@ async fn execute_host_job( ) .await; return; - } + }, Err(e) => { tracing::error!(%host_id, error = %e, "execute_host_job: DB error fetching host"); handle_host_failure(pool, pjh_id, format!("DB error fetching host: {e}")).await; return; - } + }, }; // ── 2. Fetch the job's patch_selection ────────────────────────────────── - let patch_sel: JobPatchSelection = match sqlx::query_as( - "SELECT patch_selection FROM patch_jobs WHERE id = $1", - ) - .bind(job_id) - .fetch_optional(&pool) - .await - { - Ok(Some(row)) => row, - Ok(None) => { - tracing::error!(%job_id, "execute_host_job: parent job not found"); - handle_host_failure(pool, pjh_id, format!("Parent job {job_id} not found")).await; - return; - } - Err(e) => { - tracing::error!(%job_id, error = %e, "execute_host_job: DB error fetching job"); - handle_host_failure(pool, pjh_id, format!("DB error fetching job: {e}")).await; - return; - } - }; + let patch_sel: JobPatchSelection = + match sqlx::query_as("SELECT patch_selection FROM patch_jobs WHERE id = $1") + .bind(job_id) + .fetch_optional(&pool) + .await + { + Ok(Some(row)) => row, + Ok(None) => { + tracing::error!(%job_id, "execute_host_job: parent job not found"); + handle_host_failure(pool, pjh_id, format!("Parent job {job_id} not found")).await; + return; + }, + Err(e) => { + tracing::error!(%job_id, error = %e, "execute_host_job: DB error fetching job"); + handle_host_failure(pool, pjh_id, format!("DB error fetching job: {e}")).await; + return; + }, + }; let packages: Vec = serde_json::from_value(patch_sel.patch_selection).unwrap_or_default(); @@ -403,7 +398,7 @@ async fn execute_host_job( tracing::error!(%host_id, error = %e, "execute_host_job: failed to load agent certs"); handle_host_failure(pool, pjh_id, format!("Failed to load agent certs: {e}")).await; return; - } + }, }; // ── 4. Build AgentClient ───────────────────────────────────────────────── @@ -419,7 +414,7 @@ async fn execute_host_job( tracing::error!(%host_id, error = %e, "execute_host_job: failed to build AgentClient"); handle_host_failure(pool, pjh_id, format!("Failed to build agent client: {e}")).await; return; - } + }, }; // ── 5. Mark pjh as running ─────────────────────────────────────────────── @@ -439,7 +434,10 @@ async fn execute_host_job( } // ── 6. Submit the patch job to the agent ───────────────────────────────── - let req = ApplyPatchesRequest { packages, allow_reboot: true }; + let req = ApplyPatchesRequest { + packages, + allow_reboot: true, + }; match client.apply_patches(&req).await { Ok(resp) => { @@ -450,13 +448,12 @@ async fn execute_host_job( ); // ── 7. Store agent_job_id; status stays 'running' (agent is async) ── - if let Err(e) = sqlx::query( - "UPDATE patch_job_hosts SET agent_job_id = $1 WHERE id = $2", - ) - .bind(&resp.job_id) - .bind(pjh_id) - .execute(&pool) - .await + if let Err(e) = + sqlx::query("UPDATE patch_job_hosts SET agent_job_id = $1 WHERE id = $2") + .bind(&resp.job_id) + .bind(pjh_id) + .execute(&pool) + .await { tracing::error!( %pjh_id, @@ -464,11 +461,11 @@ async fn execute_host_job( "execute_host_job: failed to store agent_job_id" ); } - } + }, Err(e) => { tracing::warn!(%pjh_id, error = %e, "execute_host_job: agent rejected job"); handle_host_failure(pool, pjh_id, format!("Agent error: {e}")).await; - } + }, } } @@ -498,7 +495,7 @@ pub async fn poll_running_jobs(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "poll_running_jobs: DB query failed"); return; - } + }, }; for row in rows { @@ -510,11 +507,7 @@ pub async fn poll_running_jobs(pool: PgPool, config: Arc) { } /// Poll one running host entry and update its status from the agent response. -async fn poll_single_host( - pool: PgPool, - config: Arc, - row: PatchJobHostRunning, -) { +async fn poll_single_host(pool: PgPool, config: Arc, row: PatchJobHostRunning) { let certs = match load_agent_certs(&config.security) { Ok(c) => c, Err(e) => { @@ -524,7 +517,7 @@ async fn poll_single_host( "poll_single_host: failed to load agent certs" ); return; - } + }, }; let client = match AgentClient::new( @@ -542,7 +535,7 @@ async fn poll_single_host( "poll_single_host: failed to build AgentClient" ); return; - } + }, }; let status = match client.job_status(&row.agent_job_id).await { @@ -555,7 +548,7 @@ async fn poll_single_host( "poll_single_host: agent status call failed" ); return; - } + }, }; match status.status.as_str() { @@ -578,14 +571,14 @@ async fn poll_single_host( tracing::error!(pjh_id = %row.id, error = %e, "poll_single_host: update failed"); } sync_job_status(&pool, row.job_id).await; - } + }, "failed" => { tracing::warn!(pjh_id = %row.id, "poll_single_host: agent job failed"); let err_msg = status .error .unwrap_or_else(|| "Agent reported failure (no detail)".to_string()); handle_host_failure(pool, row.id, err_msg).await; - } + }, "running" | "queued" => { // Still in progress — nothing to update; will poll again next cycle. tracing::debug!( @@ -593,14 +586,14 @@ async fn poll_single_host( agent_status = %status.status, "poll_single_host: job still in progress" ); - } + }, other => { tracing::warn!( pjh_id = %row.id, agent_status = %other, "poll_single_host: unexpected agent status — ignoring" ); - } + }, } } @@ -624,7 +617,7 @@ async fn handle_host_failure(pool: PgPool, pjh_id: Uuid, error_msg: String) { Err(e) => { tracing::error!(%pjh_id, error = %e, "handle_host_failure: DB error fetching retry row"); return; - } + }, }; let row = match row { @@ -632,7 +625,7 @@ async fn handle_host_failure(pool: PgPool, pjh_id: Uuid, error_msg: String) { None => { tracing::error!(%pjh_id, "handle_host_failure: pjh row not found"); return; - } + }, }; if row.retry_count < 3 { @@ -736,7 +729,7 @@ async fn sync_job_status(pool: &PgPool, job_id: Uuid) { Err(e) => { tracing::error!(%job_id, error = %e, "sync_job_status: DB query failed"); return; - } + }, }; // Determine the aggregate status. @@ -745,19 +738,19 @@ async fn sync_job_status(pool: &PgPool, job_id: Uuid) { if counts.running_count > 0 || counts.pending_count > 0 || counts.queued_count > 0 { // Still work in flight — keep parent running. - new_status = "running"; + new_status = "running"; set_completed = false; } else if counts.total_count > 0 && counts.succeeded_count == counts.total_count { // Every host succeeded. - new_status = "succeeded"; + new_status = "succeeded"; set_completed = true; } else if counts.total_count > 0 && counts.cancelled_count == counts.total_count { // Every host cancelled. - new_status = "cancelled"; + new_status = "cancelled"; set_completed = true; } else if counts.failed_count > 0 { // At least one failure and nothing still active → failed (partial counts too). - new_status = "failed"; + new_status = "failed"; set_completed = true; } else { // Fallback: nothing actionable yet. @@ -789,13 +782,11 @@ async fn sync_job_status(pool: &PgPool, job_id: Uuid) { .execute(pool) .await } else { - sqlx::query( - "UPDATE patch_jobs SET status = $2 WHERE id = $1", - ) - .bind(job_id) - .bind(new_status) - .execute(pool) - .await + sqlx::query("UPDATE patch_jobs SET status = $2 WHERE id = $1") + .bind(job_id) + .bind(new_status) + .execute(pool) + .await }; if let Err(e) = result { @@ -812,13 +803,8 @@ async fn sync_job_status(pool: &PgPool, job_id: Uuid) { let failed = counts.failed_count; tokio::spawn(async move { - email::send_job_completion_email( - &pool_clone, - &job_id_str, - total, - succeeded, - failed, - ).await; + email::send_job_completion_email(&pool_clone, &job_id_str, total, succeeded, failed) + .await; // If there are failures, also send failure emails per host if failed > 0 { @@ -838,16 +824,12 @@ async fn sync_job_status(pool: &PgPool, job_id: Uuid) { Err(e) => { tracing::error!(%job_id, error = %e, "sync_job_status: failed to fetch failed hosts for email"); Vec::new() - } + }, }; for (fqdn, error_msg) in failed_hosts { - email::send_patch_failure_email( - &pool_clone, - &fqdn, - &job_id_str, - &error_msg, - ).await; + email::send_patch_failure_email(&pool_clone, &fqdn, &job_id_str, &error_msg) + .await; } } }); @@ -878,7 +860,7 @@ pub async fn retry_pending_jobs(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "retry_pending_jobs: DB query failed"); return; - } + }, }; for row in rows { diff --git a/crates/pm-worker/src/main.rs b/crates/pm-worker/src/main.rs index 333dae8..ba08f44 100644 --- a/crates/pm-worker/src/main.rs +++ b/crates/pm-worker/src/main.rs @@ -7,27 +7,23 @@ mod agent_loader; mod audit_verifier; mod email; mod health_poller; +mod job_executor; mod maintenance_scheduler; mod patch_poller; mod refresh_listener; -mod job_executor; mod ws_relay; -use pm_core::{ - config::AppConfig, - db, - logging, -}; +use pm_core::{config::AppConfig, db, logging}; use sqlx::PgPool; use std::{sync::Arc, time::Duration}; use tokio::time; use audit_verifier::run_audit_verifier; use health_poller::run_health_poller; +use job_executor::run_job_executor; use maintenance_scheduler::run_maintenance_scheduler; use patch_poller::run_patch_poller; use refresh_listener::run_refresh_listener; -use job_executor::run_job_executor; use ws_relay::run_ws_relay; /// Minimum number of applied migrations the worker requires before @@ -44,16 +40,18 @@ async fn main() -> anyhow::Result<()> { let config_path = std::env::var("PATCH_MANAGER_CONFIG") .unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string()); - let config = AppConfig::load(&config_path) - .unwrap_or_else(|_| { - eprintln!("Config file not found or invalid, using defaults"); - AppConfig::default() - }); + let config = AppConfig::load(&config_path).unwrap_or_else(|_| { + eprintln!("Config file not found or invalid, using defaults"); + AppConfig::default() + }); // Initialize logging logging::init(&config.logging); - tracing::info!(version = env!("CARGO_PKG_VERSION"), "patch-manager-worker starting"); + tracing::info!( + version = env!("CARGO_PKG_VERSION"), + "patch-manager-worker starting" + ); // Initialize database pool let pool = db::init_pool(&config.database).await?; @@ -114,17 +112,17 @@ async fn wait_for_schema(pool: &PgPool) -> anyhow::Result<()> { Ok(count) if count >= REQUIRED_MIGRATION_COUNT => { tracing::info!(migration_count = count, "Schema version check passed"); return Ok(()); - } + }, Ok(count) => { tracing::warn!( migration_count = count, required = REQUIRED_MIGRATION_COUNT, "Schema not ready, waiting..." ); - } + }, Err(e) => { tracing::warn!(error = %e, "Schema version check failed, retrying..."); - } + }, } if tokio::time::Instant::now() >= deadline { diff --git a/crates/pm-worker/src/maintenance_scheduler.rs b/crates/pm-worker/src/maintenance_scheduler.rs index 3857f9f..880026a 100644 --- a/crates/pm-worker/src/maintenance_scheduler.rs +++ b/crates/pm-worker/src/maintenance_scheduler.rs @@ -144,7 +144,7 @@ async fn dispatch_open_window_jobs(pool: PgPool, config: Arc) { "dispatch_open_window_jobs: queued jobs query failed" ); continue; - } + }, }; for job in job_ids { diff --git a/crates/pm-worker/src/patch_poller.rs b/crates/pm-worker/src/patch_poller.rs index cf40648..82a96da 100644 --- a/crates/pm-worker/src/patch_poller.rs +++ b/crates/pm-worker/src/patch_poller.rs @@ -9,10 +9,7 @@ use std::sync::Arc; use pm_agent_client::AgentClient; use pm_core::config::AppConfig; use sqlx::{FromRow, PgPool}; -use tokio::{ - sync::Semaphore, - time, -}; +use tokio::{sync::Semaphore, time}; use uuid::Uuid; use crate::agent_loader::load_agent_certs; @@ -34,10 +31,7 @@ pub async fn run_patch_poller(pool: PgPool, config: Arc) { let interval_secs = config.worker.patch_poll_interval_secs; let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs)); - tracing::info!( - interval_secs, - "Patch poller started" - ); + tracing::info!(interval_secs, "Patch poller started"); loop { ticker.tick().await; @@ -47,7 +41,7 @@ pub async fn run_patch_poller(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "Patch poller: failed to load agent certs — skipping cycle"); continue; - } + }, }; let client_cert = Arc::new(certs.client_cert); @@ -64,7 +58,7 @@ pub async fn run_patch_poller(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "Patch poller: failed to fetch hosts"); continue; - } + }, }; if hosts.is_empty() { @@ -102,16 +96,11 @@ pub async fn run_patch_poller(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "Patch poller task panicked"); failed += 1; - } + }, } } - tracing::info!( - total, - succeeded, - failed, - "Patch poll cycle complete" - ); + tracing::info!(total, succeeded, failed, "Patch poll cycle complete"); } } @@ -135,7 +124,7 @@ async fn poll_host_patches( Err(e) => { tracing::warn!(host_id = %host.id, error = %e, "Patch poller: failed to build AgentClient"); return false; - } + }, }; // Fetch patches and packages concurrently. @@ -147,7 +136,7 @@ async fn poll_host_patches( Err(e) => { tracing::warn!(host_id = %host.id, error = %e, "Patch poller: patches() failed"); return false; - } + }, }; let packages_data = match packages_result { @@ -155,7 +144,7 @@ async fn poll_host_patches( Err(e) => { tracing::warn!(host_id = %host.id, error = %e, "Patch poller: packages_upgradable() failed"); return false; - } + }, }; let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default(); @@ -188,12 +177,10 @@ async fn poll_host_patches( } // Update hosts.last_patch_at. - if let Err(e) = sqlx::query( - "UPDATE hosts SET last_patch_at = NOW() WHERE id = $1", - ) - .bind(host.id) - .execute(&pool) - .await + if let Err(e) = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1") + .bind(host.id) + .execute(&pool) + .await { tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to update last_patch_at"); } diff --git a/crates/pm-worker/src/refresh_listener.rs b/crates/pm-worker/src/refresh_listener.rs index e51aaeb..d4b2915 100644 --- a/crates/pm-worker/src/refresh_listener.rs +++ b/crates/pm-worker/src/refresh_listener.rs @@ -8,10 +8,7 @@ use std::sync::Arc; use pm_agent_client::{AgentClient, AgentClientError}; -use pm_core::{ - config::AppConfig, - models::HostHealthStatus, -}; +use pm_core::{config::AppConfig, models::HostHealthStatus}; use sqlx::{FromRow, PgPool}; use tokio::time; use uuid::Uuid; @@ -46,8 +43,7 @@ pub async fn run_refresh_listener(pool: PgPool, config: Arc) { /// Inner loop — returns `Err` only on a fatal listener error so the outer /// loop can reconnect. async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> { - let mut listener = - sqlx::postgres::PgListener::connect(&config.database.url).await?; + let mut listener = sqlx::postgres::PgListener::connect(&config.database.url).await?; listener.listen("refresh_requested").await?; @@ -68,7 +64,7 @@ async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> { "Refresh listener: invalid UUID in notification payload" ); continue; - } + }, }; // Fetch the host from the database. @@ -85,7 +81,7 @@ async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> { None => { tracing::warn!(%host_id, "Refresh listener: host not found"); continue; - } + }, }; // Load certs for this refresh. @@ -98,7 +94,7 @@ async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> { "Refresh listener: failed to load agent certs" ); continue; - } + }, }; // Spawn the actual work so the listener loop is not blocked. @@ -137,7 +133,7 @@ async fn refresh_host( ); persist_health_unreachable(&pool, host.id).await; return; - } + }, }; // ── Health ──────────────────────────────────────────────────────────── @@ -145,15 +141,21 @@ async fn refresh_host( Ok(data) => { let payload = serde_json::to_value(&data).unwrap_or_default(); (HostHealthStatus::Healthy, payload) - } + }, Err(AgentClientError::Timeout) | Err(AgentClientError::Connect(_)) => { tracing::warn!(host_id = %host.id, "Refresh: agent unreachable"); - (HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default())) - } + ( + HostHealthStatus::Unreachable, + serde_json::Value::Object(Default::default()), + ) + }, Err(e) => { tracing::warn!(host_id = %host.id, error = %e, "Refresh: health error"); - (HostHealthStatus::Degraded, serde_json::Value::Object(Default::default())) - } + ( + HostHealthStatus::Degraded, + serde_json::Value::Object(Default::default()), + ) + }, }; persist_health(&pool, host.id, &health_status, &health_payload).await; @@ -164,8 +166,7 @@ async fn refresh_host( match (patches_result, packages_result) { (Ok(patches_data), Ok(packages_data)) => { - let available_patches = - serde_json::to_value(&patches_data.patches).unwrap_or_default(); + let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default(); let installed_packages = serde_json::to_value(&packages_data.packages).unwrap_or_default(); let patch_count = patches_data.total as i32; @@ -196,12 +197,10 @@ async fn refresh_host( "Refresh: failed to insert patch data" ); } else { - let _ = sqlx::query( - "UPDATE hosts SET last_patch_at = NOW() WHERE id = $1", - ) - .bind(host.id) - .execute(&pool) - .await; + let _ = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1") + .bind(host.id) + .execute(&pool) + .await; tracing::info!( host_id = %host.id, @@ -210,14 +209,14 @@ async fn refresh_host( "On-demand refresh complete" ); } - } + }, (Err(e), _) | (_, Err(e)) => { tracing::warn!( host_id = %host.id, error = %e, "Refresh: failed to collect patch data" ); - } + }, } } @@ -252,13 +251,12 @@ async fn persist_health( ); } - if let Err(e) = sqlx::query( - "UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1", - ) - .bind(host_id) - .bind(status) - .execute(pool) - .await + if let Err(e) = + sqlx::query("UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1") + .bind(host_id) + .bind(status) + .execute(pool) + .await { tracing::error!(%host_id, error = %e, "Refresh: failed to update host health_status"); } diff --git a/crates/pm-worker/src/ws_relay.rs b/crates/pm-worker/src/ws_relay.rs index 35aa604..ca55108 100644 --- a/crates/pm-worker/src/ws_relay.rs +++ b/crates/pm-worker/src/ws_relay.rs @@ -5,27 +5,18 @@ //! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS //! handler can forward the event to connected clients. -use std::{ - collections::HashSet, - sync::Arc, - time::Duration, -}; +use std::{collections::HashSet, sync::Arc, time::Duration}; use anyhow::Context; use futures::StreamExt; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer}, - ClientConfig as TlsClientConfig, - RootCertStore, + ClientConfig as TlsClientConfig, RootCertStore, }; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use tokio::sync::Mutex; -use tokio_tungstenite::{ - connect_async_tls_with_config, - tungstenite::protocol::Message, - Connector, -}; +use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::protocol::Message, Connector}; use uuid::Uuid; use pm_agent_client::client::DEFAULT_AGENT_PORT; @@ -84,7 +75,7 @@ pub async fn run_ws_relay(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "ws_relay: DB poll failed"); continue; - } + }, }; for row in rows { @@ -101,12 +92,12 @@ pub async fn run_ws_relay(pool: PgPool, config: Arc) { Err(e) => { tracing::error!(error = %e, "ws_relay: TLS config error"); continue; - } + }, }; active.lock().await.insert(key); - let pool_c = pool.clone(); + let pool_c = pool.clone(); let active_c = active.clone(); tokio::spawn(async move { @@ -164,12 +155,15 @@ async fn query_running_jobs(pool: &PgPool) -> anyhow::Result async fn build_tls_config(config: &AppConfig) -> anyhow::Result { let sec = &config.security; - let cert_pem = tokio::fs::read(&sec.agent_client_cert_path).await + let cert_pem = tokio::fs::read(&sec.agent_client_cert_path) + .await .with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?; - let key_pem = tokio::fs::read(&sec.agent_client_key_path).await - .with_context(|| format!("read agent client key '{}'" , sec.agent_client_key_path))?; - let ca_pem = tokio::fs::read(&sec.ca_cert_path).await - .with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?; + let key_pem = tokio::fs::read(&sec.agent_client_key_path) + .await + .with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?; + let ca_pem = tokio::fs::read(&sec.ca_cert_path) + .await + .with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?; // Parse client certificate chain. let client_certs: Vec> = { @@ -207,8 +201,8 @@ async fn build_tls_config(config: &AppConfig) -> anyhow::Result // ── Per-job relay ───────────────────────────────────────────────────────────── async fn relay_one_job( - pool: &PgPool, - row: &RunningHostJob, + pool: &PgPool, + row: &RunningHostJob, tls_config: Arc, ) -> anyhow::Result<()> { let url = format!( @@ -229,7 +223,7 @@ async fn relay_one_job( while let Some(frame) = stream.next().await { let frame = match frame { - Ok(f) => f, + Ok(f) => f, Err(e) => { tracing::warn!( error = %e, @@ -238,16 +232,16 @@ async fn relay_one_job( "WS relay: stream error" ); break; - } + }, }; let text = match frame { - Message::Text(t) => t.to_string(), + Message::Text(t) => t.to_string(), Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(), - Message::Close(_) => { + Message::Close(_) => { tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly"); break; - } + }, _ => continue, }; @@ -256,14 +250,14 @@ async fn relay_one_job( } let event: AgentWsEvent = match serde_json::from_str(&text) { - Ok(e) => e, + Ok(e) => e, Err(e) => { tracing::warn!( error = %e, raw = %text, "WS relay: unparseable agent frame" ); continue; - } + }, }; process_event(pool, row, &event).await; @@ -287,17 +281,17 @@ async fn relay_one_job( async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) { // Map agent status string to DB job_status enum value. let db_status = match event.status.as_str() { - "running" => "running", + "running" => "running", "succeeded" => "succeeded", - "failed" => "failed", + "failed" => "failed", "cancelled" => "cancelled", other => { tracing::warn!(status = %other, "WS relay: unknown agent status"); return; - } + }, }; - let output = event.output.as_deref().unwrap_or(""); + let output = event.output.as_deref().unwrap_or(""); let error_msg = event.error.as_deref(); // Determine timestamps based on terminal state. @@ -359,20 +353,20 @@ async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent // Fire pg_notify so browser WS handlers forward the event. let payload = NotifyPayload { - job_id: row.job_id.to_string(), - host_id: row.host_id.to_string(), - status: db_status.to_string(), - output: event.output.clone(), + job_id: row.job_id.to_string(), + host_id: row.host_id.to_string(), + status: db_status.to_string(), + output: event.output.clone(), error_message: event.error.clone(), agent_job_id: row.agent_job_id.clone(), }; let payload_json = match serde_json::to_string(&payload) { - Ok(s) => s, + Ok(s) => s, Err(e) => { tracing::error!(error = %e, "WS relay: failed to serialize notify payload"); return; - } + }, }; if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)") @@ -418,11 +412,11 @@ async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) { .fetch_one(pool) .await { - Ok(n) => n, + Ok(n) => n, Err(e) => { tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed"); return; - } + }, }; if pending > 0 { @@ -437,14 +431,18 @@ async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) { .fetch_one(pool) .await { - Ok(n) => n, + Ok(n) => n, Err(e) => { tracing::error!(error = %e, %job_id, "update_parent_job_status: failed-count query failed"); return; - } + }, }; - let final_status = if failed_count > 0 { "failed" } else { "succeeded" }; + let final_status = if failed_count > 0 { + "failed" + } else { + "succeeded" + }; if let Err(e) = sqlx::query( "UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2", diff --git a/rustfmt.toml b/rustfmt.toml index 977d0bc..5644e86 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,5 +1,6 @@ # Linux Patch Manager - Rust Formatting Configuration # Run: cargo fmt --check (CI) or cargo fmt (fix) +# Only stable options - nightly-only options removed edition = "2021" max_width = 100 @@ -10,15 +11,4 @@ use_small_heuristics = "Default" reorder_imports = true reorder_modules = true remove_nested_parens = true -fn_single_line = false -where_single_line = false -imports_granularity = "Crate" -group_imports = "StdExternalCrate" -normalize_doc_attributes = true -wrap_comments = true -comment_width = 80 -indent_style = "Block" -trailing_comma = "Vertical" match_block_trailing_comma = true -blank_lines_lower_bound = 0 -blank_lines_upper_bound = 1