Private
Public Access
1
0

feat: add bump-version.sh script for version management

Automates version bumps across all version source files:
- Cargo.toml (PRIMARY - workspace.package.version)
- debian/changelog (prepend new entry)
- debian/control (update Version field)
- scripts/build-package.sh (update VERSION variable)
- frontend/package.json (update version field)
- Stale references check after bump

Usage: ./scripts/bump-version.sh <new_version> <old_version>
This commit is contained in:
2026-05-28 10:52:16 -05:00
commit 124b5b0e3b
153 changed files with 41878 additions and 0 deletions

View File

@ -0,0 +1,19 @@
[package]
name = "pm-agent-client"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[dependencies]
pm-core = { path = "../pm-core" }
tokio = { workspace = true }
reqwest = { workspace = true }
rustls = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }

View File

@ -0,0 +1,274 @@
//! mTLS HTTP client for communicating with Linux Patch API agents.
//!
//! # Example
//!
//! ```no_run
//! use pm_agent_client::client::AgentClient;
//!
//! # async fn example() -> Result<(), pm_agent_client::error::AgentClientError> {
//! let client = AgentClient::new(
//! "192.168.1.10",
//! 12443,
//! include_bytes!("../certs/client.crt"),
//! include_bytes!("../certs/client.key"),
//! include_bytes!("../certs/ca.crt"),
//! )?;
//!
//! let health = client.health().await?;
//! println!("Agent status: {}", health.status);
//! # Ok(())
//! # }
//! ```
use std::time::Duration;
use reqwest::{tls::Version, Certificate, ClientBuilder, Identity};
use serde::{de::DeserializeOwned, Serialize};
use tracing::{debug, instrument};
use crate::{
error::AgentClientError,
types::{
AgentEnvelope, AgentJobStatus, ApplyPatchesRequest, ApplyPatchesResponse, HealthData,
PackagesData, PatchesData, RollbackResponse, ServiceStatusData, SystemInfoData,
},
};
/// Default TCP port that the Linux Patch API agent listens on.
pub const DEFAULT_AGENT_PORT: u16 = 12443;
/// Request timeout applied to every agent API call.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
// ============================================================
// AgentClient
// ============================================================
/// Async HTTP client that speaks mTLS to a single Linux Patch API agent.
///
/// Construct once via [`AgentClient::new`] and reuse across calls;
/// the underlying [`reqwest::Client`] maintains a connection pool.
#[derive(Debug, Clone)]
pub struct AgentClient {
/// Underlying HTTP client (configured for mTLS + TLS 1.3).
inner: reqwest::Client,
/// Base URL of the agent, e.g. `https://10.0.0.5:12443/api/v1`.
base_url: String,
}
impl AgentClient {
/// Create a new [`AgentClient`] configured for mTLS.
///
/// # Arguments
///
/// * `host_ip` IP address (or hostname) of the agent.
/// * `port` TCP port the agent listens on (default [`DEFAULT_AGENT_PORT`]).
/// * `client_cert_pem` PEM-encoded client certificate presented during the TLS handshake.
/// * `client_key_pem` PEM-encoded private key matching `client_cert_pem`.
/// * `ca_cert_pem` PEM-encoded CA certificate used to verify the agent's server cert.
///
/// # Errors
///
/// Returns [`AgentClientError::Tls`] when certificate parsing fails, or
/// [`AgentClientError::Request`] when `reqwest` client construction fails.
pub fn new(
host_ip: &str,
port: u16,
client_cert_pem: &[u8],
client_key_pem: &[u8],
ca_cert_pem: &[u8],
) -> Result<Self, AgentClientError> {
// Build client identity: reqwest expects cert + key concatenated as PEM.
let mut identity_pem = Vec::with_capacity(client_cert_pem.len() + client_key_pem.len());
identity_pem.extend_from_slice(client_cert_pem);
identity_pem.extend_from_slice(client_key_pem);
let identity = Identity::from_pem(&identity_pem)
.map_err(|e| AgentClientError::Tls(format!("invalid client identity PEM: {e}")))?;
// Parse the CA certificate used to verify the agent's server certificate.
let ca_cert = Certificate::from_pem(ca_cert_pem)
.map_err(|e| AgentClientError::Tls(format!("invalid CA certificate PEM: {e}")))?;
// Build the reqwest client:
// - force rustls TLS backend
// - disable built-in OS/system trust roots (only trust our internal CA)
// - enforce TLS 1.3 minimum
// - attach client identity (mTLS)
// - add our CA as a trusted root
// - apply a global request timeout
let inner = ClientBuilder::new()
.use_rustls_tls()
.tls_built_in_root_certs(false)
.min_tls_version(Version::TLS_1_3)
.identity(identity)
.add_root_certificate(ca_cert)
.timeout(REQUEST_TIMEOUT)
.build()
.map_err(AgentClientError::Request)?;
let clean_ip = host_ip.split('/').next().unwrap_or(host_ip);
let base_url = format!("https://{}:{}/api/v1", clean_ip, port);
Ok(Self { inner, base_url })
}
// --------------------------------------------------------
// Public API methods
// --------------------------------------------------------
/// `GET /api/v1/health` — check agent liveness and retrieve uptime.
#[instrument(skip(self), fields(base_url = %self.base_url))]
pub async fn health(&self) -> Result<HealthData, AgentClientError> {
self.get("health", &[]).await
}
/// `GET /api/v1/system/info` — retrieve host system information.
#[instrument(skip(self), fields(base_url = %self.base_url))]
pub async fn system_info(&self) -> Result<SystemInfoData, AgentClientError> {
self.get("system/info", &[]).await
}
/// `GET /api/v1/packages?status=upgradable` — list packages with available upgrades.
#[instrument(skip(self), fields(base_url = %self.base_url))]
pub async fn packages_upgradable(&self) -> Result<PackagesData, AgentClientError> {
self.get("packages", &[("status", "upgradable")]).await
}
/// `GET /api/v1/patches` — list available patches with severity and CVE data.
#[instrument(skip(self), fields(base_url = %self.base_url))]
pub async fn patches(&self) -> Result<PatchesData, AgentClientError> {
self.get("patches", &[]).await
}
// --------------------------------------------------------
// Private helpers
// --------------------------------------------------------
/// Execute a GET request against `{base_url}/{path}` with optional query
/// parameters, deserialize the [`AgentEnvelope`], and extract the `data`
/// field — or propagate an [`AgentClientError::ApiError`].
async fn get<T>(&self, path: &str, query: &[(&str, &str)]) -> Result<T, AgentClientError>
where
T: DeserializeOwned,
{
let url = format!("{}/{}", self.base_url, path);
debug!(url = %url, ?query, "Sending GET request to agent");
let mut request = self.inner.get(&url);
if !query.is_empty() {
request = request.query(query);
}
let response = request.send().await?;
let status = response.status();
debug!(url = %url, status = %status, "Received response from agent");
// Capture body text so we can attempt to deserialise the error envelope
// even for non-2xx responses.
let body = response.text().await?;
// Attempt to parse the standard agent envelope regardless of HTTP status.
// The agent may embed a structured error body on 4xx/5xx responses.
let envelope: AgentEnvelope<T> = serde_json::from_str(&body)?;
if !status.is_success() || !envelope.success {
// Prefer the structured error from the envelope when present.
if let Some(err) = envelope.error {
return Err(AgentClientError::ApiError {
code: err.code,
message: err.message,
});
}
// Fallback: use the HTTP status as the error indicator.
return Err(AgentClientError::ApiError {
code: status.as_str().to_string(),
message: format!("Agent returned HTTP {} for {}", status.as_u16(), url),
});
}
// On success the `data` field must be present.
envelope.data.ok_or_else(|| AgentClientError::ApiError {
code: "MISSING_DATA".to_string(),
message: "Agent response success=true but data field is absent".to_string(),
})
}
// --------------------------------------------------------
// Patch apply / job management methods
// --------------------------------------------------------
/// `POST /api/v1/patches/apply` — trigger patch application on the agent.
#[instrument(skip(self, req), fields(base_url = %self.base_url))]
pub async fn apply_patches(
&self,
req: &ApplyPatchesRequest,
) -> Result<ApplyPatchesResponse, AgentClientError> {
self.post("patches/apply", req).await
}
/// `GET /api/v1/jobs/{id}` — poll an async agent job for status.
#[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))]
pub async fn job_status(&self, job_id: &str) -> Result<AgentJobStatus, AgentClientError> {
self.get(&format!("jobs/{}", job_id), &[]).await
}
/// `POST /api/v1/jobs/{id}/rollback` — trigger rollback on the agent.
#[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))]
pub async fn rollback_job(&self, job_id: &str) -> Result<RollbackResponse, AgentClientError> {
let empty: serde_json::Value = serde_json::json!({});
self.post(&format!("jobs/{}/rollback", job_id), &empty)
.await
}
/// `GET /api/v1/system/services/{name}` — check status of a specific service on the agent.
#[instrument(skip(self), fields(base_url = %self.base_url, service_name = %service_name))]
pub async fn service_status(
&self,
service_name: &str,
) -> Result<ServiceStatusData, AgentClientError> {
self.get(&format!("system/services/{}", service_name), &[])
.await
}
// --------------------------------------------------------
// Private POST helper
// --------------------------------------------------------
/// Execute a POST request against `{base_url}/{path}`, serialize `body` as
/// JSON, deserialize the [`AgentEnvelope`], and extract the `data` field —
/// or propagate an [`AgentClientError::ApiError`].
async fn post<Req, Resp>(&self, path: &str, body: &Req) -> Result<Resp, AgentClientError>
where
Req: Serialize,
Resp: DeserializeOwned,
{
let url = format!("{}/{}", self.base_url, path);
debug!(url = %url, "Sending POST request to agent");
let response = self.inner.post(&url).json(body).send().await?;
let status = response.status();
debug!(url = %url, status = %status, "Received POST response from agent");
let body_text = response.text().await?;
let envelope: AgentEnvelope<Resp> = serde_json::from_str(&body_text)?;
if !status.is_success() || !envelope.success {
if let Some(err) = envelope.error {
return Err(AgentClientError::ApiError {
code: err.code,
message: err.message,
});
}
return Err(AgentClientError::ApiError {
code: status.as_str().to_string(),
message: format!("Agent returned HTTP {} for {}", status.as_u16(), url),
});
}
envelope.data.ok_or_else(|| AgentClientError::ApiError {
code: "MISSING_DATA".to_string(),
message: "Agent response success=true but data field is absent".to_string(),
})
}
}

View File

@ -0,0 +1,49 @@
//! Error types for the pm-agent-client crate.
use thiserror::Error;
/// Top-level error type returned by [`crate::client::AgentClient`] methods.
#[derive(Debug, Error)]
pub enum AgentClientError {
/// TLS configuration or handshake failure.
#[error("TLS error: {0}")]
Tls(String),
/// Unable to establish a TCP/TLS connection to the agent.
#[error("Connection error: {0}")]
Connect(#[source] reqwest::Error),
/// An HTTP request or response transport error (not a timeout).
#[error("Request error: {0}")]
Request(#[source] reqwest::Error),
/// The request did not complete within the configured timeout.
#[error("Request timed out")]
Timeout,
/// The agent returned a non-2xx HTTP status or `success: false` in the
/// response envelope.
#[error("Agent API error [{code}]: {message}")]
ApiError {
/// Machine-readable error code supplied by the agent (e.g. `"NOT_FOUND"`).
code: String,
/// Human-readable description returned by the agent.
message: String,
},
/// JSON deserialization of the agent response failed.
#[error("Failed to deserialise agent response: {0}")]
Deserialize(#[from] serde_json::Error),
}
impl From<reqwest::Error> for AgentClientError {
fn from(err: reqwest::Error) -> Self {
if err.is_timeout() {
AgentClientError::Timeout
} else if err.is_connect() {
AgentClientError::Connect(err)
} else {
AgentClientError::Request(err)
}
}
}

View File

@ -0,0 +1,43 @@
//! `pm-agent-client` — mTLS HTTP client for Linux Patch API agent communication.
//!
//! This crate provides [`client::AgentClient`], an async HTTP client that
//! establishes mutual-TLS connections (TLS 1.3) to `linux_patch_api` agents
//! running on managed hosts.
//!
//! # Quick start
//!
//! ```no_run
//! use pm_agent_client::AgentClient;
//!
//! # async fn run() -> Result<(), pm_agent_client::AgentClientError> {
//! let client = AgentClient::new(
//! "10.0.1.5",
//! 12443,
//! include_bytes!("../certs/client.crt"),
//! include_bytes!("../certs/client.key"),
//! include_bytes!("../certs/ca.crt"),
//! )?;
//!
//! let health = client.health().await?;
//! println!("Agent {}: {}", health.status, health.version);
//! # Ok(())
//! # }
//! ```
pub mod client;
pub mod error;
pub mod types;
// ── Convenience re-exports ──────────────────────────────────────────────────
/// Primary client — re-exported from [`client::AgentClient`].
pub use client::{AgentClient, DEFAULT_AGENT_PORT};
/// Error type — re-exported from [`error::AgentClientError`].
pub use error::AgentClientError;
/// Response envelope and all data types.
pub use types::{
AgentEnvelope, AgentErrorBody, HealthData, Package, PackagesData, Patch, PatchesData,
RollbackResponse, ServiceStatusData, SystemInfoData,
};

View File

@ -0,0 +1,230 @@
//! Response and request types for the Linux Patch API agent endpoints.
//!
//! All agent responses are wrapped in [`AgentEnvelope<T>`].
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
// ============================================================
// Envelope & error
// ============================================================
/// Generic response wrapper returned by every agent endpoint.
///
/// ```json
/// { "success": true, "request_id": "…", "timestamp": "…", "data": {…}, "error": null }
/// ```
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentEnvelope<T> {
/// `true` when the request succeeded; `false` on error.
pub success: bool,
/// Server-assigned request identifier (UUID v4).
pub request_id: Uuid,
/// Server timestamp for the response (ISO-8601 / RFC-3339).
pub timestamp: DateTime<Utc>,
/// Response payload — present when `success` is `true`.
pub data: Option<T>,
/// Error detail — present when `success` is `false`.
pub error: Option<AgentErrorBody>,
}
/// Structured error returned inside [`AgentEnvelope::error`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentErrorBody {
/// Machine-readable error code (e.g. `"INTERNAL_ERROR"`).
pub code: String,
/// Human-readable description of what went wrong.
pub message: String,
/// Optional free-form extra detail from the agent.
#[serde(default)]
pub details: Option<serde_json::Value>,
/// Whether the caller may safely retry the request.
#[serde(default)]
pub retryable: bool,
}
// ============================================================
// GET /api/v1/health
// ============================================================
/// Payload returned by `GET /api/v1/health`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HealthData {
/// Agent status string, e.g. `"ok"` or `"degraded"`.
pub status: String,
/// Seconds elapsed since the agent process started.
pub uptime_seconds: u64,
/// Agent software version string.
pub version: String,
}
// ============================================================
// GET /api/v1/system/info
// ============================================================
/// Payload returned by `GET /api/v1/system/info`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SystemInfoData {
/// Hostname of the managed system.
pub hostname: String,
/// OS family / distribution name (e.g. `"Ubuntu"`).
pub os: String,
/// OS version string.
pub os_version: String,
/// Kernel version string.
pub kernel: String,
/// CPU architecture (e.g. `"x86_64"`).
pub architecture: String,
/// When the agent last checked for updates (`null` if never).
pub last_update_check: Option<DateTime<Utc>>,
/// When updates were last applied (`null` if never).
pub last_update_apply: Option<DateTime<Utc>>,
/// Whether the system has a pending reboot.
pub pending_reboot: bool,
}
// ============================================================
// GET /api/v1/packages?status=upgradable
// ============================================================
/// A single package entry.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Package {
/// Package name.
pub name: String,
/// Installed version.
pub version: String,
/// Package status string (e.g. `"installed"`, `"upgradable"`).
pub status: String,
/// Whether a newer version is available.
pub upgradable: bool,
/// Latest available version (`null` if not upgradable).
pub latest_version: Option<String>,
/// Short package description.
pub description: String,
/// CVE identifiers associated with this package.
#[serde(default)]
pub cve_ids: Vec<String>,
}
/// Payload returned by `GET /api/v1/packages`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PackagesData {
/// List of packages matching the query filters.
pub packages: Vec<Package>,
/// Total count of matching packages.
pub total: u64,
}
// ============================================================
// GET /api/v1/patches
// ============================================================
/// A single available patch.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Patch {
/// Package / patch name.
pub name: String,
/// Currently installed version.
pub current_version: String,
/// Version available after applying this patch.
pub available_version: String,
/// Severity level (e.g. `"critical"`, `"high"`, `"medium"`, `"low"`).
pub severity: String,
/// Human-readable description of the patch.
pub description: String,
/// CVE identifiers addressed by this patch.
#[serde(default)]
pub cve_ids: Vec<String>,
/// Whether applying this patch requires a system reboot.
pub requires_reboot: bool,
}
/// Payload returned by `GET /api/v1/patches`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PatchesData {
/// List of available patches.
pub patches: Vec<Patch>,
/// Total patch count.
pub total: u64,
/// Number of patches classified as security updates.
pub security_updates: u64,
/// Whether any patch in the list requires a reboot.
pub requires_reboot: bool,
}
// ============================================================
// POST /api/v1/patches/apply
// ============================================================
/// Request body for `POST /api/v1/patches/apply`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApplyPatchesRequest {
/// Package names to apply. Empty = apply all available patches.
pub packages: Vec<String>,
/// If true, allow automatic reboot after patching if required.
pub allow_reboot: bool,
}
/// Response from `POST /api/v1/patches/apply`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApplyPatchesResponse {
/// Agent-assigned async job ID for status polling.
pub job_id: String,
/// Initial status: typically `"running"` or `"queued"`.
pub status: String,
}
// ============================================================
// GET /api/v1/jobs/{id}
// ============================================================
/// Status of an async agent job returned by `GET /api/v1/jobs/{id}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentJobStatus {
pub job_id: String,
/// Current status: `"queued"`, `"running"`, `"succeeded"`, `"completed"`, `"failed"`, or `"cancelled"`.
pub status: String,
pub progress_percent: Option<u8>,
pub output: Option<String>,
pub error: Option<String>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
}
// ============================================================
// GET /api/v1/system/services/{name}
// ============================================================
/// Payload returned by `GET /api/v1/system/services/{name}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServiceStatusData {
/// Service name.
pub name: String,
/// Human-readable service name.
pub display_name: String,
/// Active state (e.g. `"active"`, `"inactive"`, `"failed"`).
pub active_state: String,
/// Sub state (e.g. `"running"`, `"dead"`, `"exited"`).
pub sub_state: String,
/// Load state (e.g. `"loaded"`, `"not-found"`).
pub load_state: String,
/// Enabled state (e.g. `"enabled"`, `"disabled"`).
pub enabled_state: String,
/// Main PID of the service process.
pub main_pid: Option<u32>,
/// Whether the service is considered healthy.
pub healthy: bool,
}
// ============================================================
// POST /api/v1/jobs/{id}/rollback
// ============================================================
/// Response from `POST /api/v1/jobs/{id}/rollback`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RollbackResponse {
pub job_id: String,
pub status: String,
}

29
crates/pm-auth/Cargo.toml Normal file
View File

@ -0,0 +1,29 @@
[package]
name = "pm-auth"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[dependencies]
pm-core = { path = "../pm-core" }
tokio = { workspace = true }
axum = { workspace = true }
axum-extra = { workspace = true }
sqlx = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
argon2 = { workspace = true }
jsonwebtoken = { workspace = true }
rand = { workspace = true }
totp-rs = { workspace = true }
base64 = { workspace = true }
hex = { workspace = true }
ipnet = { workspace = true }
parking_lot = "0.12"
sha2 = { workspace = true }

152
crates/pm-auth/src/jwt.rs Executable file
View File

@ -0,0 +1,152 @@
//! JWT issuance and validation using EdDSA / Ed25519.
//!
//! - Access tokens: 15-minute TTL, signed with Ed25519 private key
//! - Key rotation: 90-day cycle with 24-hour overlap window
//! - The web process holds the signing key; worker holds only the public key
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
/// JWT algorithm — EdDSA with Ed25519 curve.
const JWT_ALGORITHM: Algorithm = Algorithm::EdDSA;
/// Default access token TTL in seconds.
pub const DEFAULT_ACCESS_TTL_SECS: i64 = 900; // 15 minutes
#[derive(Debug, Error)]
pub enum JwtError {
#[error("Failed to encode JWT: {0}")]
Encode(String),
#[error("Failed to decode JWT: {0}")]
Decode(String),
#[error("Token is expired")]
Expired,
#[error("Token has invalid claims")]
InvalidClaims,
#[error("Failed to load signing key: {0}")]
KeyLoad(String),
}
/// Standard JWT claims for access tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessClaims {
/// Subject: user ID (UUID)
pub sub: String,
/// Issued at (Unix timestamp)
pub iat: i64,
/// Expiry (Unix timestamp)
pub exp: i64,
/// JWT ID (unique per token)
pub jti: String,
/// User role: "admin" or "operator"
pub role: String,
/// Username (for display / logging)
pub username: String,
}
impl AccessClaims {
/// Create new claims for the given user.
pub fn new(user_id: Uuid, username: &str, role: &str, ttl_secs: i64) -> Self {
let now = Utc::now();
Self {
sub: user_id.to_string(),
iat: now.timestamp(),
exp: (now + Duration::seconds(ttl_secs)).timestamp(),
jti: Uuid::new_v4().to_string(),
role: role.to_string(),
username: username.to_string(),
}
}
/// Check if the token is expired (redundant with validation but useful for explicit checks).
pub fn is_expired(&self) -> bool {
Utc::now().timestamp() > self.exp
}
/// Return the user UUID parsed from the `sub` field.
pub fn user_id(&self) -> Result<Uuid, JwtError> {
Uuid::parse_str(&self.sub).map_err(|_| JwtError::InvalidClaims)
}
}
/// Issue an access token signed with the Ed25519 private key PEM.
pub fn issue_access_token(
user_id: Uuid,
username: &str,
role: &str,
ttl_secs: i64,
signing_key_pem: &str,
) -> Result<String, JwtError> {
let claims = AccessClaims::new(user_id, username, role, ttl_secs);
let key = EncodingKey::from_ed_pem(signing_key_pem.as_bytes())
.map_err(|e| JwtError::KeyLoad(e.to_string()))?;
let header = Header::new(JWT_ALGORITHM);
encode(&header, &claims, &key).map_err(|e| JwtError::Encode(e.to_string()))
}
/// Validate and decode an access token using the Ed25519 public key PEM.
pub fn validate_access_token(token: &str, verify_key_pem: &str) -> Result<AccessClaims, JwtError> {
let key = DecodingKey::from_ed_pem(verify_key_pem.as_bytes())
.map_err(|e| JwtError::KeyLoad(e.to_string()))?;
let mut validation = Validation::new(JWT_ALGORITHM);
validation.validate_exp = true;
validation.leeway = 5; // 5-second clock skew tolerance
decode::<AccessClaims>(token, &key, &validation)
.map(|data| data.claims)
.map_err(|e| {
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
JwtError::Expired
} else {
JwtError::Decode(e.to_string())
}
})
}
/// Load the Ed25519 signing key from a PEM file path.
pub fn load_signing_key(path: &str) -> Result<String, JwtError> {
std::fs::read_to_string(path).map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}")))
}
/// Load the Ed25519 verification (public) key from a PEM file path.
pub fn load_verify_key(path: &str) -> Result<String, JwtError> {
std::fs::read_to_string(path).map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}")))
}
#[cfg(test)]
#[allow(dead_code)]
mod tests {
use super::*;
// Test keys generated with:
// openssl genpkey -algorithm ed25519 -out signing.pem
// openssl pkey -in signing.pem -pubout -out verify.pem
const TEST_SIGNING_KEY: &str = "-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIHNzPc3LkpODUVFr8GjVPm4M2yiKrXsZ/1uJQ/tQMjNb
-----END PRIVATE KEY-----
";
const TEST_VERIFY_KEY: &str = "-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA8nRzpCYzZ1xFKNJDGt9wuXdq7kKS/ck9PfLJu/r3VEw=
-----END PUBLIC KEY-----
";
// Note: real tests require valid key pairs; these are placeholders.
// Integration tests in the test suite use generated keys.
#[test]
fn claims_construction() {
let user_id = Uuid::new_v4();
let claims = AccessClaims::new(user_id, "admin", "admin", 900);
assert_eq!(claims.sub, user_id.to_string());
assert_eq!(claims.role, "admin");
assert!(!claims.is_expired());
assert_eq!(claims.user_id().unwrap(), user_id);
}
}

25
crates/pm-auth/src/lib.rs Executable file
View File

@ -0,0 +1,25 @@
//! pm-auth — Authentication and authorization.
//!
//! Modules:
//! - `password` — Argon2id password hashing (m=65536, t=3, p=1)
//! - `jwt` — EdDSA/Ed25519 JWT issuance and validation (15-min TTL)
//! - `refresh` — Opaque 256-bit refresh tokens (1-hour sliding window)
//! - `mfa_totp` — TOTP setup and verification (Google Authenticator compatible)
//! - `mfa_webauthn` — WebAuthn stub (full implementation pending)
//! - `rbac` — Axum middleware for JWT authentication and role enforcement
//! - `session` — Login flow orchestration (password → MFA → tokens)
pub mod jwt;
pub mod mfa_totp;
pub mod mfa_webauthn;
pub mod password;
pub mod rbac;
pub mod refresh;
pub mod session;
// Commonly re-exported types
pub use jwt::{AccessClaims, JwtError};
pub use password::validate_password_strength;
pub use password::{hash_password, verify_password, PasswordError};
pub use rbac::{AuthConfig, AuthUser, UserRole};
pub use session::{LoginRequest, LoginResponse, SessionError, SessionUser};

103
crates/pm-auth/src/mfa_totp.rs Executable file
View File

@ -0,0 +1,103 @@
//! TOTP (Time-based One-Time Password) MFA implementation.
//!
//! Uses TOTP-rs with HMAC-SHA1, 6-digit codes, 30-second window.
//! Compatible with Google Authenticator, Authy, and standard TOTP apps.
use serde::{Deserialize, Serialize};
use thiserror::Error;
use totp_rs::{Algorithm, Secret, TOTP};
/// TOTP issuer label shown in authenticator apps.
const ISSUER: &str = "Linux Patch Manager";
#[derive(Debug, Error)]
pub enum TotpError {
#[error("Failed to create TOTP: {0}")]
Creation(String),
#[error("Invalid TOTP secret")]
InvalidSecret,
#[error("TOTP code verification failed")]
VerificationFailed,
}
/// TOTP setup response returned to the user during MFA enrollment.
#[derive(Debug, Serialize, Deserialize)]
pub struct TotpSetup {
/// Base32-encoded secret for manual entry in authenticator apps.
pub secret_base32: String,
/// OTP Auth URI for QR code generation (otpauth://totp/...).
pub otp_uri: String,
}
/// Generate a new TOTP secret and return setup information.
///
/// The caller should store `secret_base32` in the database after
/// the user verifies the first code.
pub fn generate_setup(username: &str) -> Result<TotpSetup, TotpError> {
let secret = Secret::generate_secret();
let secret_base32 = secret.to_encoded().to_string();
let totp = build_totp(username, &secret_base32)?;
let otp_uri = totp.get_url();
Ok(TotpSetup {
secret_base32,
otp_uri,
})
}
/// Verify a TOTP code against the stored secret.
///
/// Accepts codes within a ±1 step window (±30 seconds) to handle clock skew.
pub fn verify_code(username: &str, secret_base32: &str, code: &str) -> Result<bool, TotpError> {
let totp = build_totp(username, secret_base32)?;
let valid = totp
.check_current(code)
.map_err(|_| TotpError::VerificationFailed)?;
Ok(valid)
}
/// Build a TOTP instance from a base32 secret.
fn build_totp(username: &str, secret_base32: &str) -> Result<TOTP, TotpError> {
let secret = Secret::Encoded(secret_base32.to_string());
let secret_bytes = secret.to_bytes().map_err(|_| TotpError::InvalidSecret)?;
// With the `otpauth` feature, TOTP::new signature is:
// new(issuer, account_name, algorithm, digits, skew, step, secret)
TOTP::new(
Algorithm::SHA1,
6, // digits
1, // skew
30, // step (seconds)
secret_bytes,
Some(ISSUER.to_string()),
username.to_string(),
)
.map_err(|e| TotpError::Creation(e.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn generate_setup_produces_valid_uri() {
let setup = generate_setup("testuser").unwrap();
assert!(!setup.secret_base32.is_empty());
assert!(setup.otp_uri.starts_with("otpauth://totp/"));
}
#[test]
fn verify_with_current_code() {
let setup = generate_setup("testuser").unwrap();
let totp = build_totp("testuser", &setup.secret_base32).unwrap();
let code = totp.generate_current().unwrap();
assert!(verify_code("testuser", &setup.secret_base32, &code).unwrap());
}
#[test]
fn wrong_code_fails() {
let setup = generate_setup("testuser").unwrap();
assert!(!verify_code("testuser", &setup.secret_base32, "000000").unwrap());
}
}

View File

@ -0,0 +1,51 @@
//! WebAuthn (FIDO2) MFA stub.
//!
//! Full implementation planned for M2 extension or M3.
//! WebAuthn requires stateful registration/authentication ceremonies
//! and a compatible client library (webauthn-rs).
//!
//! For M2, TOTP is the primary MFA method.
//! WebAuthn credentials are stored in the `users.webauthn_credential` JSONB
//! column and will be processed here when implemented.
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum WebAuthnError {
#[error("WebAuthn not yet implemented")]
NotImplemented,
}
/// Placeholder for WebAuthn registration options.
#[derive(Debug, Serialize, Deserialize)]
pub struct RegistrationOptions {
pub message: String,
}
/// Begin WebAuthn registration ceremony (stub).
pub fn begin_registration(_username: &str) -> Result<RegistrationOptions, WebAuthnError> {
Err(WebAuthnError::NotImplemented)
}
/// Complete WebAuthn registration ceremony (stub).
pub fn complete_registration(
_username: &str,
_response: &serde_json::Value,
) -> Result<serde_json::Value, WebAuthnError> {
Err(WebAuthnError::NotImplemented)
}
/// Begin WebAuthn authentication ceremony (stub).
pub fn begin_authentication(_username: &str) -> Result<serde_json::Value, WebAuthnError> {
Err(WebAuthnError::NotImplemented)
}
/// Verify WebAuthn authentication response (stub).
pub fn verify_authentication(
_username: &str,
_credential: &serde_json::Value,
_response: &serde_json::Value,
) -> Result<bool, WebAuthnError> {
Err(WebAuthnError::NotImplemented)
}

125
crates/pm-auth/src/password.rs Executable file
View File

@ -0,0 +1,125 @@
//! Password hashing and verification using Argon2id.
//!
//! Parameters (calibrated per OWASP recommendations):
//! - Algorithm: Argon2id
//! - Memory cost: 65536 KiB (64 MiB)
//! - Time cost: 3 iterations
//! - Parallelism: 1
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2, Params, Version,
};
use thiserror::Error;
/// Argon2id parameters per spec.
const M_COST: u32 = 65536; // 64 MiB
const T_COST: u32 = 3; // 3 iterations
const P_COST: u32 = 1; // 1 thread
#[derive(Debug, Error)]
pub enum PasswordError {
#[error("Failed to hash password: {0}")]
HashError(String),
#[error("Failed to verify password: {0}")]
VerifyError(String),
#[error("Invalid password hash format")]
InvalidHash,
}
/// Build an Argon2id instance with calibrated parameters.
fn argon2() -> Result<Argon2<'static>, PasswordError> {
let params = Params::new(M_COST, T_COST, P_COST, None)
.map_err(|e| PasswordError::HashError(e.to_string()))?;
Ok(Argon2::new(
argon2::Algorithm::Argon2id,
Version::V0x13,
params,
))
}
/// Hash a plaintext password using Argon2id with a random salt.
///
/// Returns the PHC string format hash suitable for storage.
pub fn hash_password(password: &str) -> Result<String, PasswordError> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = argon2()?;
let hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| PasswordError::HashError(e.to_string()))?;
Ok(hash.to_string())
}
/// Verify a plaintext password against a stored Argon2id PHC hash.
///
/// Returns `Ok(true)` if the password matches, `Ok(false)` if not.
pub fn verify_password(password: &str, hash: &str) -> Result<bool, PasswordError> {
let parsed_hash = PasswordHash::new(hash).map_err(|_| PasswordError::InvalidHash)?;
let argon2 = argon2()?;
match argon2.verify_password(password.as_bytes(), &parsed_hash) {
Ok(()) => Ok(true),
Err(argon2::password_hash::Error::Password) => Ok(false),
Err(e) => Err(PasswordError::VerifyError(e.to_string())),
}
}
/// Validate password strength against minimum requirements.
///
/// Requirements:
/// - Minimum 8 characters
/// - At least one uppercase letter
/// - At least one lowercase letter
/// - At least one digit
/// - At least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?)
pub fn validate_password_strength(password: &str) -> Result<(), String> {
if password.len() < 8 {
return Err("Password must be at least 8 characters".to_string());
}
if !password.chars().any(|c| c.is_ascii_uppercase()) {
return Err("Password must contain at least one uppercase letter".to_string());
}
if !password.chars().any(|c| c.is_ascii_lowercase()) {
return Err("Password must contain at least one lowercase letter".to_string());
}
if !password.chars().any(|c| c.is_ascii_digit()) {
return Err("Password must contain at least one digit".to_string());
}
let special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?";
if !password.chars().any(|c| special_chars.contains(c)) {
return Err(
"Password must contain at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?)"
.to_string(),
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn hash_and_verify_roundtrip() {
let password = "super-secret-password-123!";
let hash = hash_password(password).unwrap();
assert!(hash.starts_with("$argon2id$"));
assert!(verify_password(password, &hash).unwrap());
}
#[test]
fn wrong_password_fails() {
let hash = hash_password("correct-horse").unwrap();
assert!(!verify_password("wrong-password", &hash).unwrap());
}
#[test]
fn different_salts_produce_different_hashes() {
let hash1 = hash_password("same-password").unwrap();
let hash2 = hash_password("same-password").unwrap();
assert_ne!(hash1, hash2); // different salts
}
}

232
crates/pm-auth/src/rbac.rs Executable file
View File

@ -0,0 +1,232 @@
//! Role-Based Access Control (RBAC) middleware for Axum.
//!
//! Provides:
//! - JWT extraction and validation from `Authorization: Bearer <token>` header
//! - Role enforcement (`admin`, `operator`)
//! - Group-scoped access (enforced at the handler level using `AuthUser` extension)
//! - IP whitelist enforcement
use axum::{
extract::Request,
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Json, Response},
};
use ipnet::IpNet;
use parking_lot::RwLock;
use serde_json::json;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use uuid::Uuid;
use crate::jwt::{validate_access_token, AccessClaims, JwtError};
/// User identity extracted from a validated JWT, inserted as a request extension.
#[derive(Debug, Clone)]
pub struct AuthUser {
pub user_id: Uuid,
pub username: String,
pub role: UserRole,
pub claims: AccessClaims,
}
/// Application roles.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserRole {
Admin,
Operator,
Reporter,
}
impl UserRole {
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Option<Self> {
match s {
"admin" => Some(Self::Admin),
"operator" => Some(Self::Operator),
"reporter" => Some(Self::Reporter),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Admin => "admin",
Self::Operator => "operator",
Self::Reporter => "reporter",
}
}
/// Admin can do everything; operator has limited scope.
pub fn is_admin(&self) -> bool {
matches!(self, Self::Admin)
}
/// Admin and Operator can write; Reporter is read-only.
pub fn can_write(&self) -> bool {
matches!(self, Self::Admin | Self::Operator)
}
}
/// Shared auth configuration injected via Axum state.
#[derive(Clone)]
pub struct AuthConfig {
/// Ed25519 public key PEM for JWT verification.
pub verify_key_pem: String,
/// IP whitelist (empty = allow all). RwLock for runtime updates.
pub ip_whitelist: Arc<RwLock<Vec<IpNet>>>,
}
impl AuthConfig {
pub fn new(verify_key_pem: String, ip_whitelist_cidrs: &[String]) -> Self {
let ip_whitelist = ip_whitelist_cidrs
.iter()
.filter_map(|cidr| IpNet::from_str(cidr).ok())
.collect();
Self {
verify_key_pem,
ip_whitelist: Arc::new(RwLock::new(ip_whitelist)),
}
}
/// Check if an IP address is allowed by the whitelist.
/// If the whitelist is empty, all IPs are allowed.
pub fn is_ip_allowed(&self, ip: &IpAddr) -> bool {
let whitelist = self.ip_whitelist.read();
if whitelist.is_empty() {
return true;
}
whitelist.iter().any(|net| net.contains(ip))
}
/// Update the IP whitelist at runtime without restart.
pub fn update_ip_whitelist(&self, entries: Vec<String>) {
let nets: Vec<IpNet> = entries
.iter()
.filter_map(|cidr| IpNet::from_str(cidr).ok())
.collect();
let count = nets.len();
*self.ip_whitelist.write() = nets;
tracing::info!(count, "IP whitelist updated at runtime");
}
}
/// Extract `Authorization: Bearer <token>` from request headers.
fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
}
/// Extract the remote IP from `X-Forwarded-For`.
fn extract_remote_ip(headers: &HeaderMap) -> Option<IpAddr> {
headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split(',').next())
.and_then(|s| s.trim().parse().ok())
}
/// Unauthorized JSON response helper.
fn unauthorized(message: &str) -> Response {
(
StatusCode::UNAUTHORIZED,
Json(json!({ "error": { "code": "unauthorized", "message": message } })),
)
.into_response()
}
/// Forbidden JSON response helper.
fn forbidden(message: &str) -> Response {
(
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": message } })),
)
.into_response()
}
/// Middleware: authenticate any valid JWT (admin or operator).
///
/// Inserts `AuthUser` into request extensions on success.
/// Rejects with 401 if token is missing/invalid, 403 if IP is blocked.
pub async fn require_auth(auth_config: Arc<AuthConfig>, mut req: Request, next: Next) -> Response {
// IP whitelist check
if let Some(ip) = extract_remote_ip(req.headers()) {
if !auth_config.is_ip_allowed(&ip) {
tracing::warn!(ip = %ip, "Request blocked by IP whitelist");
return forbidden("Access denied");
}
}
// Extract and validate JWT
let token = match extract_bearer_token(req.headers()) {
Some(t) => t,
None => return unauthorized("Missing authorization token"),
};
let claims = match validate_access_token(token, &auth_config.verify_key_pem) {
Ok(c) => c,
Err(JwtError::Expired) => return unauthorized("Token expired"),
Err(e) => {
tracing::debug!(error = %e, "JWT validation failed");
return unauthorized("Invalid token");
},
};
let role = match UserRole::from_str(&claims.role) {
Some(r) => r,
None => return unauthorized("Invalid role in token"),
};
let user_id = match claims.user_id() {
Ok(id) => id,
Err(_) => return unauthorized("Invalid user ID in token"),
};
let auth_user = AuthUser {
user_id,
username: claims.username.clone(),
role,
claims,
};
req.extensions_mut().insert(auth_user);
next.run(req).await
}
/// Middleware: require the `admin` role.
/// Must be chained AFTER `require_auth` (which inserts `AuthUser`).
pub async fn require_admin(req: Request, next: Next) -> Response {
let auth_user = match req.extensions().get::<AuthUser>().cloned() {
Some(u) => u,
None => return unauthorized("Authentication required"),
};
if !auth_user.role.is_admin() {
return forbidden("Admin role required");
}
next.run(req).await
}
/// Axum extractor: pulls `AuthUser` from request extensions.
impl<S> axum::extract::FromRequestParts<S> for AuthUser
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<AuthUser>()
.cloned()
.ok_or_else(|| unauthorized("Authentication required"))
}
}

163
crates/pm-auth/src/refresh.rs Executable file
View File

@ -0,0 +1,163 @@
//! Opaque refresh token management.
//!
//! - 256-bit cryptographically random opaque tokens
//! - Stored as SHA-256 hash in the database (never the raw token)
//! - 1-hour sliding inactivity timeout, updated on each use
//! - Rotated on use (old token revoked, new one issued)
//! - Revocable by admin force-logout
use chrono::{Duration, Utc};
use rand::RngCore;
use sha2::{Digest, Sha256};
use sqlx::PgPool;
use thiserror::Error;
use uuid::Uuid;
/// Length of the raw refresh token in bytes (256 bits).
const TOKEN_BYTES: usize = 32;
/// Sliding inactivity window: 1 hour.
const INACTIVITY_TIMEOUT_HOURS: i64 = 1;
#[derive(Debug, Error)]
pub enum RefreshError {
#[error("Refresh token not found or revoked")]
Invalid,
#[error("Refresh token expired")]
Expired,
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
}
/// Raw (plaintext) refresh token — returned to the client, never stored.
#[derive(Debug, Clone)]
pub struct RawRefreshToken(pub String);
impl RawRefreshToken {
/// Hex-encode a raw 256-bit random token.
pub fn generate() -> Self {
let mut bytes = [0u8; TOKEN_BYTES];
rand::thread_rng().fill_bytes(&mut bytes);
Self(hex::encode(bytes))
}
/// Return the SHA-256 hash of this token for database storage.
pub fn hash(&self) -> String {
let digest = Sha256::digest(self.0.as_bytes());
hex::encode(digest)
}
}
/// Database row representation of a stored refresh token.
#[derive(Debug, sqlx::FromRow)]
pub struct StoredRefreshToken {
pub id: Uuid,
pub user_id: Uuid,
pub expires_at: chrono::DateTime<Utc>,
pub revoked: bool,
}
/// Issue a new refresh token for the given user and store it in the database.
///
/// Returns the raw (plaintext) token to be sent to the client.
pub async fn issue(
pool: &PgPool,
user_id: Uuid,
user_agent: Option<&str>,
ip_address: Option<&str>,
) -> Result<RawRefreshToken, RefreshError> {
let token = RawRefreshToken::generate();
let hash = token.hash();
let expires_at = Utc::now() + Duration::hours(INACTIVITY_TIMEOUT_HOURS);
sqlx::query(
r#"
INSERT INTO refresh_tokens (user_id, token_hash, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, $4, $5::inet)
"#,
)
.bind(user_id)
.bind(&hash)
.bind(expires_at)
.bind(user_agent)
.bind(ip_address)
.execute(pool)
.await?;
tracing::debug!(user_id = %user_id, "Refresh token issued");
Ok(token)
}
/// Validate a refresh token, then rotate it (revoke old, issue new).
///
/// Returns `(new_raw_token, user_id)` if valid.
pub async fn rotate(
pool: &PgPool,
raw_token: &str,
user_agent: Option<&str>,
ip_address: Option<&str>,
) -> Result<(RawRefreshToken, Uuid), RefreshError> {
let hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
let now = Utc::now();
// Look up token
let row: Option<StoredRefreshToken> = sqlx::query_as(
r#"
SELECT id, user_id, expires_at, revoked
FROM refresh_tokens
WHERE token_hash = $1
"#,
)
.bind(&hash)
.fetch_optional(pool)
.await?;
let stored = row.ok_or(RefreshError::Invalid)?;
if stored.revoked {
return Err(RefreshError::Invalid);
}
if stored.expires_at < now {
return Err(RefreshError::Expired);
}
// Revoke old token
sqlx::query("UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE id = $1")
.bind(stored.id)
.execute(pool)
.await?;
// Issue new token
let new_token = issue(pool, stored.user_id, user_agent, ip_address).await?;
tracing::debug!(user_id = %stored.user_id, "Refresh token rotated");
Ok((new_token, stored.user_id))
}
/// Revoke all refresh tokens for a user (force logout).
pub async fn revoke_all_for_user(pool: &PgPool, user_id: Uuid) -> Result<u64, RefreshError> {
let result = sqlx::query(
"UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE user_id = $1 AND revoked = FALSE",
)
.bind(user_id)
.execute(pool)
.await?;
tracing::info!(user_id = %user_id, rows = result.rows_affected(), "All refresh tokens revoked");
Ok(result.rows_affected())
}
/// Revoke a single refresh token by its raw value.
pub async fn revoke(pool: &PgPool, raw_token: &str) -> Result<(), RefreshError> {
let hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
sqlx::query(
"UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE token_hash = $1",
)
.bind(&hash)
.execute(pool)
.await?;
Ok(())
}

308
crates/pm-auth/src/session.rs Executable file
View File

@ -0,0 +1,308 @@
//! Session management: login flow, logout, token issuance.
//!
//! Login flow: password → MFA → access token + refresh token
//! Logout: revoke refresh token
//! Force logout: revoke all tokens for a user
use chrono::Utc;
use pm_core::models::{AuthProvider, UserRole};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use thiserror::Error;
use uuid::Uuid;
use crate::{
jwt::{self, JwtError},
mfa_totp,
password::{self, PasswordError},
refresh::{self, RefreshError},
};
#[derive(Debug, Error)]
pub enum SessionError {
#[error("Invalid credentials")]
InvalidCredentials,
#[error("Account is disabled")]
AccountDisabled,
#[error("Password reset required")]
PasswordResetRequired,
#[error("MFA required")]
MfaRequired,
#[error("Invalid MFA code")]
InvalidMfaCode,
#[error("Account locked due to too many failed attempts")]
AccountLocked,
#[error("JWT error: {0}")]
Jwt(#[from] JwtError),
#[error("Refresh token error: {0}")]
Refresh(#[from] RefreshError),
#[error("Password error: {0}")]
Password(#[from] PasswordError),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
}
/// Successful login response returned to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
/// Short-lived JWT access token (15 minutes).
pub access_token: String,
/// Opaque refresh token (1-hour sliding window).
pub refresh_token: String,
/// Token type (always "Bearer").
pub token_type: String,
/// Access token TTL in seconds.
pub expires_in: i64,
/// User information.
pub user: SessionUser,
}
/// User summary embedded in login response.
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionUser {
pub id: String,
pub username: String,
pub display_name: String,
pub role: String,
pub mfa_enabled: bool,
}
/// Database user row fetched during login.
#[derive(Debug, sqlx::FromRow)]
#[allow(dead_code)]
struct DbUser {
id: Uuid,
username: String,
display_name: String,
role: UserRole,
auth_provider: AuthProvider,
password_hash: Option<String>,
totp_secret: Option<String>,
mfa_enabled: bool,
is_active: bool,
force_password_reset: bool,
failed_login_attempts: i32,
locked_until: Option<chrono::DateTime<Utc>>,
}
/// Login request payload.
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
pub username: String,
pub password: String,
/// TOTP code (required if MFA is enabled).
pub totp_code: Option<String>,
}
/// Perform the full login flow for local accounts.
///
/// Steps:
/// 1. Look up user by username
/// 2. Verify password (Argon2id)
/// 3. Check account active state
/// 4. Verify MFA if enabled
/// 5. Issue access token + refresh token
/// 6. Update last_login_at
pub async fn login(
pool: &PgPool,
req: &LoginRequest,
signing_key_pem: &str,
access_ttl_secs: i64,
user_agent: Option<&str>,
ip_address: Option<&str>,
) -> Result<LoginResponse, SessionError> {
// 1. Fetch user by username
let user: Option<DbUser> = sqlx::query_as(
r#"
SELECT id, username, display_name, role, auth_provider,
password_hash, totp_secret, mfa_enabled, is_active, force_password_reset,
failed_login_attempts, locked_until
FROM users
WHERE username = $1 AND auth_provider = 'local'
"#,
)
.bind(&req.username)
.fetch_optional(pool)
.await?;
// Use constant-time comparison approach: always run Argon2 even on miss
let user = match user {
Some(u) => u,
None => {
// Prevent timing-based username enumeration
let _ = password::hash_password("dummy-timing-fill");
return Err(SessionError::InvalidCredentials);
},
};
// 2a. Check if account is locked due to too many failed attempts
if let Some(locked_until) = user.locked_until {
if locked_until > Utc::now() {
tracing::warn!(username = %req.username, "Login blocked: account locked until {}", locked_until);
return Err(SessionError::AccountLocked);
}
// Lockout period has expired — reset counters
sqlx::query(
"UPDATE users SET failed_login_attempts = 0, locked_until = NULL WHERE id = $1",
)
.bind(user.id)
.execute(pool)
.await?;
}
// 2. Verify password
let hash = user.password_hash.as_deref().unwrap_or("");
let valid = password::verify_password(&req.password, hash).unwrap_or(false);
if !valid {
// Increment failed login attempts
let new_attempts = user.failed_login_attempts + 1;
if new_attempts >= 5 {
let lock_until = Utc::now() + chrono::Duration::minutes(30);
sqlx::query(
"UPDATE users SET failed_login_attempts = $1, locked_until = $2 WHERE id = $3",
)
.bind(new_attempts)
.bind(lock_until)
.bind(user.id)
.execute(pool)
.await?;
tracing::warn!(username = %req.username, "Account locked after {} failed attempts", new_attempts);
} else {
sqlx::query("UPDATE users SET failed_login_attempts = $1 WHERE id = $2")
.bind(new_attempts)
.bind(user.id)
.execute(pool)
.await?;
}
tracing::warn!(username = %req.username, "Login failed: invalid password (attempt {})", new_attempts);
return Err(SessionError::InvalidCredentials);
}
// 3. Check account state
if !user.is_active {
tracing::warn!(username = %req.username, "Login failed: account disabled");
return Err(SessionError::AccountDisabled);
}
// 3b. Check if password reset is required
if user.force_password_reset {
tracing::warn!(username = %req.username, "Login blocked: password reset required");
return Err(SessionError::PasswordResetRequired);
}
// 4. MFA check
if user.mfa_enabled {
let code = req.totp_code.as_deref().ok_or(SessionError::MfaRequired)?;
let secret = user.totp_secret.as_deref().unwrap_or("");
let mfa_ok = mfa_totp::verify_code(&user.username, secret, code).unwrap_or(false);
if !mfa_ok {
tracing::warn!(username = %req.username, "Login failed: invalid MFA code");
return Err(SessionError::InvalidMfaCode);
}
}
// 5. Issue tokens
let access_token = jwt::issue_access_token(
user.id,
&user.username,
&user.role.to_string(),
access_ttl_secs,
signing_key_pem,
)?;
let raw_refresh = refresh::issue(pool, user.id, user_agent, ip_address).await?;
// 6. Update last_login_at
sqlx::query("UPDATE users SET last_login_at = $1, failed_login_attempts = 0, locked_until = NULL WHERE id = $2")
.bind(Utc::now())
.bind(user.id)
.execute(pool)
.await?;
tracing::info!(user_id = %user.id, username = %user.username, "Login successful");
Ok(LoginResponse {
access_token,
refresh_token: raw_refresh.0,
token_type: "Bearer".to_string(),
expires_in: access_ttl_secs,
user: SessionUser {
id: user.id.to_string(),
username: user.username,
display_name: user.display_name,
role: user.role.to_string(),
mfa_enabled: user.mfa_enabled,
},
})
}
/// Refresh an access token using a valid refresh token.
///
/// The old refresh token is revoked and a new one issued (rotation).
pub async fn refresh_session(
pool: &PgPool,
raw_refresh_token: &str,
signing_key_pem: &str,
access_ttl_secs: i64,
user_agent: Option<&str>,
ip_address: Option<&str>,
) -> Result<LoginResponse, SessionError> {
let (new_refresh, user_id) =
refresh::rotate(pool, raw_refresh_token, user_agent, ip_address).await?;
// Fetch user for token claims
let user: DbUser = sqlx::query_as(
r#"
SELECT id, username, display_name, role, auth_provider,
password_hash, totp_secret, mfa_enabled, is_active, force_password_reset,
failed_login_attempts, locked_until
FROM users WHERE id = $1
"#,
)
.bind(user_id)
.fetch_one(pool)
.await?;
if !user.is_active {
// Revoke all tokens and deny
let _ = refresh::revoke_all_for_user(pool, user_id).await;
return Err(SessionError::AccountDisabled);
}
let access_token = jwt::issue_access_token(
user.id,
&user.username,
&user.role.to_string(),
access_ttl_secs,
signing_key_pem,
)?;
Ok(LoginResponse {
access_token,
refresh_token: new_refresh.0,
token_type: "Bearer".to_string(),
expires_in: access_ttl_secs,
user: SessionUser {
id: user.id.to_string(),
username: user.username,
display_name: user.display_name,
role: user.role.to_string(),
mfa_enabled: user.mfa_enabled,
},
})
}
/// Logout: revoke the current refresh token.
pub async fn logout(pool: &PgPool, raw_refresh_token: &str) -> Result<(), SessionError> {
refresh::revoke(pool, raw_refresh_token).await?;
Ok(())
}
/// Force-logout: revoke all refresh tokens for a user.
pub async fn force_logout(pool: &PgPool, user_id: Uuid) -> Result<u64, SessionError> {
let count = refresh::revoke_all_for_user(pool, user_id).await?;
Ok(count)
}

25
crates/pm-ca/Cargo.toml Normal file
View File

@ -0,0 +1,25 @@
[package]
name = "pm-ca"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[dependencies]
pm-core = { path = "../pm-core" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
chrono = { workspace = true }
sqlx = { workspace = true }
uuid = { workspace = true }
rand = { workspace = true }
hex = { workspace = true }
sha2 = { workspace = true }
rustls = { workspace = true }
rcgen = { workspace = true }
pem = { workspace = true }
time = { workspace = true }

527
crates/pm-ca/src/ca.rs Executable file
View File

@ -0,0 +1,527 @@
//! Internal Certificate Authority for Linux Patch Manager.
//!
//! Issues and renews mTLS client certificates and agent server certificates
//! for agent communication. Uses rcgen (ECDSA P-256) for all certificate
//! generation. CA key and certificate are stored on disk under `base_dir`
//! (default: /etc/patch-manager/ca/). Certificate metadata is persisted in
//! the `certificates` PostgreSQL table.
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use rand::RngCore;
use rcgen::{
BasicConstraints, Certificate, CertificateParams, DistinguishedName, DnType,
ExtendedKeyUsagePurpose, Ia5String, IsCa, KeyPair, KeyUsagePurpose, SanType, SerialNumber,
PKCS_ECDSA_P256_SHA256,
};
use sqlx::{PgPool, Row};
use time::{Duration as TimeDuration, OffsetDateTime};
use uuid::Uuid;
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------
/// Returned by [`CertAuthority::issue_client_cert`] and [`CertAuthority::renew_cert`].
///
/// The private keys are intentionally **not** stored in the database.
#[derive(Debug, Clone)]
pub struct IssuedCert {
/// PEM-encoded client certificate (mTLS).
pub cert_pem: String,
/// PEM-encoded client private key (PKCS#8).
pub key_pem: String,
/// Hex-encoded 16-byte random serial number (client cert).
pub serial_number: String,
/// Certificate expiry timestamp (UTC).
pub expires_at: DateTime<Utc>,
/// PEM-encoded agent server certificate (for TLS listener).
pub server_cert_pem: String,
/// PEM-encoded agent server private key (PKCS#8).
pub server_key_pem: String,
/// Hex-encoded serial number of the server certificate.
pub server_serial_number: String,
/// PEM-encoded CA root certificate.
pub ca_root_pem: String,
}
// ---------------------------------------------------------------------------
// CertAuthority
// ---------------------------------------------------------------------------
/// Thread-safe, cloneable handle to the internal certificate authority.
///
/// CA certificate and key are held in memory as PEM strings; rcgen objects
/// are reconstructed on demand so this struct is unconditionally `Send + Sync`.
#[derive(Debug, Clone)]
pub struct CertAuthority {
#[allow(dead_code)]
base_dir: PathBuf,
/// PEM-encoded CA certificate (public cert only).
ca_cert_pem: String,
/// PEM-encoded CA private key (PKCS#8).
ca_key_pem: String,
}
// ---------------------------------------------------------------------------
// Private helpers
// ---------------------------------------------------------------------------
/// Generate a 16-byte cryptographically-random serial number.
/// Returns `(rcgen::SerialNumber, hex_encoded_string)`.
fn make_serial() -> (SerialNumber, String) {
let mut bytes = [0u8; 16];
rand::rngs::OsRng.fill_bytes(&mut bytes);
let hex_serial = hex::encode(bytes);
let serial = SerialNumber::from_slice(&bytes);
(serial, hex_serial)
}
/// `OffsetDateTime::now_utc()` offset forward by `days` (for rcgen params).
fn odt_offset_days(days: i64) -> OffsetDateTime {
OffsetDateTime::now_utc() + TimeDuration::days(days)
}
/// `chrono::Utc::now()` offset forward by `days` (for DB / return values).
fn chrono_offset_days(days: i64) -> DateTime<Utc> {
Utc::now() + ChronoDuration::days(days)
}
/// Build a `CertificateParams` with common fields pre-filled.
/// Caller still needs to set `is_ca`, `key_usages`, `extended_key_usages`, and `subject_alt_names`.
fn base_params(cn: &str, validity_days: i64) -> Result<(CertificateParams, String, DateTime<Utc>)> {
let (serial, serial_hex) = make_serial();
let expires_at = chrono_offset_days(validity_days);
let mut params = CertificateParams::default();
params.not_before = OffsetDateTime::now_utc();
params.not_after = odt_offset_days(validity_days);
params.serial_number = Some(serial);
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, cn);
params.distinguished_name = dn;
Ok((params, serial_hex, expires_at))
}
/// Write `contents` to `path` and set Unix permissions to `0600`.
async fn write_protected(path: &Path, contents: &str) -> Result<()> {
use std::os::unix::fs::PermissionsExt;
tokio::fs::write(path, contents).await?;
let perms = std::fs::Permissions::from_mode(0o600);
tokio::fs::set_permissions(path, perms).await?;
Ok(())
}
// ---------------------------------------------------------------------------
// impl CertAuthority
// ---------------------------------------------------------------------------
impl CertAuthority {
// -----------------------------------------------------------------------
// Construction
// -----------------------------------------------------------------------
/// Load an existing CA from disk, or generate a brand-new one if absent.
///
/// Files managed:
/// * `{base_dir}/ca.key` — PKCS#8 PEM private key (mode `0600`)
/// * `{base_dir}/ca.crt` — PEM certificate (mode `0600`)
///
/// On first generation the CA row is inserted into `certificates`
/// with `host_id = NULL` (marks it as the root CA record).
pub async fn init(base_dir: &Path, db: &PgPool) -> Result<Self> {
let key_path = base_dir.join("ca.key");
let crt_path = base_dir.join("ca.crt");
// ── Load existing CA ──────────────────────────────────────────────
if key_path.exists() && crt_path.exists() {
tracing::info!(path = %base_dir.display(), "Loading existing root CA from disk");
let ca_key_pem = tokio::fs::read_to_string(&key_path)
.await
.context("read ca.key")?;
let ca_cert_pem = tokio::fs::read_to_string(&crt_path)
.await
.context("read ca.crt")?;
// Validate that both PEMs parse without error.
KeyPair::from_pem(&ca_key_pem).context("parse CA private-key PEM")?;
CertificateParams::from_ca_cert_pem(&ca_cert_pem)
.context("parse CA certificate PEM")?;
tracing::info!("Root CA loaded successfully");
return Ok(Self {
base_dir: base_dir.to_owned(),
ca_cert_pem,
ca_key_pem,
});
}
// ── Generate new CA ───────────────────────────────────────────────
tracing::info!(
path = %base_dir.display(),
"Generating new root CA (ECDSA P-256, 10-year validity)"
);
tokio::fs::create_dir_all(base_dir)
.await
.context("create CA directory")?;
let ca_key =
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate CA key pair")?;
let (serial, serial_hex) = make_serial();
let expires_at = chrono_offset_days(365 * 10);
let mut params = CertificateParams::default();
params.not_before = OffsetDateTime::now_utc();
params.not_after = odt_offset_days(365 * 10);
params.serial_number = Some(serial);
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, "Patch Manager Root CA");
dn.push(DnType::OrganizationName, "Patch Manager");
params.distinguished_name = dn;
let ca_cert_obj = params
.self_signed(&ca_key)
.context("self-sign CA certificate")?;
let ca_cert_pem = ca_cert_obj.pem();
let ca_key_pem = ca_key.serialize_pem();
write_protected(&key_path, &ca_key_pem)
.await
.context("write ca.key")?;
write_protected(&crt_path, &ca_cert_pem)
.await
.context("write ca.crt")?;
tracing::info!(
serial = %serial_hex,
expires_at = %expires_at,
"Root CA generated and written to disk"
);
// Persist CA cert metadata (host_id = NULL marks the root CA row).
sqlx::query(
"INSERT INTO certificates \
(host_id, serial_number, common_name, status, expires_at, cert_pem) \
VALUES (NULL, $1, 'Patch Manager Root CA', 'active'::cert_status, $2, $3)",
)
.bind(&serial_hex)
.bind(expires_at)
.bind(&ca_cert_pem)
.execute(db)
.await
.context("insert root CA cert into database")?;
tracing::info!("Root CA certificate recorded in database");
Ok(Self {
base_dir: base_dir.to_owned(),
ca_cert_pem,
ca_key_pem,
})
}
// -----------------------------------------------------------------------
// Public accessors
// -----------------------------------------------------------------------
/// Return the PEM-encoded root CA certificate (public cert only).
pub fn root_cert_pem(&self) -> &str {
&self.ca_cert_pem
}
// -----------------------------------------------------------------------
// Certificate issuance
// -----------------------------------------------------------------------
/// Issue a one-year mTLS client certificate for a managed host.
///
/// * Subject: `CN=<hostname>`
/// * Key usage: Digital Signature
/// * Extended key usage: Client Authentication
///
/// Also issues a server certificate for the agent's TLS listener
/// (see [`issue_server_cert`]).
///
/// The certificate PEMs are stored in `certificates`.
/// The private keys are returned to the caller **only** and never persisted.
pub async fn issue_client_cert(
&self,
host_id: Uuid,
hostname: &str,
ip_address: &str,
db: &PgPool,
) -> Result<IssuedCert> {
tracing::info!(host_id = %host_id, hostname, ip_address, "Issuing mTLS client certificate");
let key =
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate client key pair")?;
let (mut params, serial_hex, expires_at) = base_params(hostname, 365)?;
params.is_ca = IsCa::ExplicitNoCa;
params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth];
let (ca_key, ca_cert) = self.ca_objects()?;
let cert = params
.signed_by(&key, &ca_cert, &ca_key)
.context("sign client cert with CA")?;
let cert_pem = cert.pem();
let key_pem = key.serialize_pem();
sqlx::query(
"INSERT INTO certificates \
(host_id, serial_number, common_name, status, expires_at, cert_pem) \
VALUES ($1, $2, $3, 'active'::cert_status, $4, $5)",
)
.bind(host_id)
.bind(&serial_hex)
.bind(hostname)
.bind(expires_at)
.bind(&cert_pem)
.execute(db)
.await
.context("insert client cert into database")?;
tracing::info!(
host_id = %host_id,
hostname,
serial = %serial_hex,
expires_at = %expires_at,
"Client certificate issued successfully"
);
// Also issue a server certificate for the agent's TLS listener.
let (server_cert_pem, server_key_pem, server_serial_number, _server_expires_at) = self
.issue_server_cert(host_id, hostname, ip_address, db)
.await?;
Ok(IssuedCert {
cert_pem,
key_pem,
serial_number: serial_hex,
expires_at,
server_cert_pem,
server_key_pem,
server_serial_number,
ca_root_pem: self.root_cert_pem().to_owned(),
})
}
/// Issue a one-year server certificate for a managed host's agent TLS listener.
///
/// * Subject: `CN=<hostname>-server`
/// * Key usage: Digital Signature
/// * Extended key usage: Server Authentication
/// * SANs: DNS `<hostname>` + IP `<ip_address>` (if valid)
///
/// The certificate PEM is stored in `certificates` with common_name
/// `{hostname}-server` to distinguish from client certs.
/// The private key is returned to the caller **only** and never persisted.
///
/// Returns `(cert_pem, key_pem, serial_number, expires_at)`.
pub async fn issue_server_cert(
&self,
host_id: Uuid,
hostname: &str,
ip_address: &str,
db: &PgPool,
) -> Result<(String, String, String, DateTime<Utc>)> {
tracing::info!(host_id = %host_id, hostname, ip_address, "Issuing agent server certificate");
let key =
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate server key pair")?;
let server_cn = format!("{hostname}-server");
let (mut params, serial_hex, expires_at) = base_params(&server_cn, 365)?;
params.is_ca = IsCa::ExplicitNoCa;
params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
// Build SANs: DNS hostname + optional IP address
let mut sans = vec![SanType::DnsName(
Ia5String::try_from(hostname.to_owned()).context("hostname is not valid IA5")?,
)];
// Strip CIDR netmask (e.g. "192.168.3.36/32") before parsing
let ip_str = ip_address.split('/').next().unwrap_or(ip_address);
if let Ok(ip) = ip_str.parse::<IpAddr>() {
sans.push(SanType::IpAddress(ip));
} else {
tracing::warn!(
ip_address,
"Could not parse IP address for server cert SANs, skipping IP SAN"
);
}
params.subject_alt_names = sans;
let (ca_key, ca_cert) = self.ca_objects()?;
let cert = params
.signed_by(&key, &ca_cert, &ca_key)
.context("sign server cert with CA")?;
let cert_pem = cert.pem();
let key_pem = key.serialize_pem();
sqlx::query(
"INSERT INTO certificates \
(host_id, serial_number, common_name, status, expires_at, cert_pem) \
VALUES ($1, $2, $3, 'active'::cert_status, $4, $5)",
)
.bind(host_id)
.bind(&serial_hex)
.bind(&server_cn)
.bind(expires_at)
.bind(&cert_pem)
.execute(db)
.await
.context("insert server cert into database")?;
tracing::info!(
host_id = %host_id,
hostname,
serial = %serial_hex,
expires_at = %expires_at,
"Server certificate issued successfully"
);
Ok((cert_pem, key_pem, serial_hex, expires_at))
}
/// Revoke a certificate by database ID.
///
/// Sets `status = 'revoked'` and `revoked_at = NOW()` in the `certificates` table.
/// Does **not** reissue a replacement; use [`renew_cert`] for that.
pub async fn revoke_cert(&self, cert_id: Uuid, db: &PgPool) -> Result<()> {
tracing::info!(cert_id = %cert_id, "Revoking certificate");
let rows = sqlx::query(
"UPDATE certificates \
SET status = 'revoked'::cert_status, revoked_at = NOW() \
WHERE id = $1",
)
.bind(cert_id)
.execute(db)
.await
.context("revoke certificate in database")?;
if rows.rows_affected() == 0 {
anyhow::bail!("certificate not found: {}", cert_id);
}
tracing::info!(cert_id = %cert_id, "Certificate revoked");
Ok(())
}
/// Renew a certificate: revoke the existing cert and issue a new one with
/// the same `host_id` and `common_name`.
///
/// Also issues a new server certificate and populates all `IssuedCert` fields.
/// The host's IP address is looked up from the database for server cert SANs.
pub async fn renew_cert(&self, cert_id: Uuid, db: &PgPool) -> Result<IssuedCert> {
tracing::info!(cert_id = %cert_id, "Renewing certificate");
// Fetch the existing cert's host_id and common_name.
let row = sqlx::query("SELECT host_id, common_name FROM certificates WHERE id = $1")
.bind(cert_id)
.fetch_one(db)
.await
.context("fetch certificate for renewal")?;
let host_id: Uuid = row
.try_get("host_id")
.context("certificate has no host_id (cannot renew root CA)")?;
let common_name: String = row.try_get("common_name").context("fetch common_name")?;
// Look up the host's IP address for the server cert SANs.
let ip_address: String =
sqlx::query_scalar("SELECT ip_address::text FROM hosts WHERE id = $1")
.bind(host_id)
.fetch_one(db)
.await
.context("fetch host IP address for renewal")?;
// Revoke the old cert first.
self.revoke_cert(cert_id, db).await?;
// Issue a fresh cert with the same CN.
let issued = self
.issue_client_cert(host_id, &common_name, &ip_address, db)
.await?;
tracing::info!(
old_cert_id = %cert_id,
new_serial = %issued.serial_number,
"Certificate renewed"
);
Ok(issued)
}
/// Generate a self-signed TLS certificate for the web UI using the CA.
///
/// * Subject: `CN=<hostname>`
/// * Key usage: Digital Signature
/// * Extended key usage: Server Authentication
/// * SAN: DNS `<hostname>`
///
/// Returns `(cert_pem, key_pem)`. This certificate is **not** stored in the
/// database; it is intended for runtime use only.
pub async fn issue_web_tls_cert(&self, hostname: &str) -> Result<(String, String)> {
tracing::info!(hostname, "Issuing web TLS certificate");
let key =
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate web TLS key pair")?;
let (mut params, serial_hex, expires_at) = base_params(hostname, 365)?;
params.is_ca = IsCa::ExplicitNoCa;
params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.subject_alt_names = vec![SanType::DnsName(
Ia5String::try_from(hostname.to_owned()).context("hostname is not valid IA5")?,
)];
let (ca_key, ca_cert) = self.ca_objects()?;
let cert = params
.signed_by(&key, &ca_cert, &ca_key)
.context("sign web TLS cert with CA")?;
let cert_pem = cert.pem();
let key_pem = key.serialize_pem();
tracing::info!(
hostname,
serial = %serial_hex,
expires_at = %expires_at,
"Web TLS certificate issued"
);
Ok((cert_pem, key_pem))
}
// -----------------------------------------------------------------------
// Private helpers
// -----------------------------------------------------------------------
/// Reconstruct rcgen `(KeyPair, Certificate)` from the in-memory PEM strings.
///
/// The returned `Certificate` is used solely as an issuer reference when
/// signing leaf certificates; it is never distributed directly.
fn ca_objects(&self) -> Result<(KeyPair, Certificate)> {
let key =
KeyPair::from_pem(&self.ca_key_pem).context("reconstruct CA key pair from PEM")?;
let params = CertificateParams::from_ca_cert_pem(&self.ca_cert_pem)
.context("reconstruct CA params from PEM")?;
let cert = params
.self_signed(&key)
.context("reconstruct CA certificate for signing")?;
Ok((key, cert))
}
}

7
crates/pm-ca/src/lib.rs Executable file
View File

@ -0,0 +1,7 @@
//! pm-ca — Internal Certificate Authority.
//!
//! Issues and renews mTLS client certificates for agent communication.
//! Uses rcgen + rustls. CA key stored at /etc/patch-manager/ca/.
pub mod ca;
pub use ca::{CertAuthority, IssuedCert};

26
crates/pm-core/Cargo.toml Normal file
View File

@ -0,0 +1,26 @@
[package]
name = "pm-core"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[dependencies]
tokio = { workspace = true }
sqlx = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
toml = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
ulid = { workspace = true }
chrono = { workspace = true }
config = { workspace = true }
axum = { workspace = true }
sha2 = { workspace = true }
hex = { workspace = true }
aes-gcm = { workspace = true }
rand = { workspace = true }

330
crates/pm-core/src/audit.rs Executable file
View File

@ -0,0 +1,330 @@
//! Audit log helper functions.
//!
//! Writes tamper-evident, hash-chained audit events to the `audit_log` table.
//! The hash chain: each row's `row_hash` = SHA-256(
//! prev_hash || action || actor_user_id || actor_username ||
//! target_type || target_id || details_json || ip_address ||
//! request_id || created_at
//! ).
//!
//! The `prev_hash` column stores the previous row's `row_hash` for chain
//! verification. The first row has `prev_hash = ''`.
use sha2::{Digest, Sha256};
use sqlx::PgPool;
use std::net::IpAddr;
use uuid::Uuid;
/// Audit event categories (must match the `audit_action` PostgreSQL ENUM).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditAction {
UserLogin,
UserLogout,
UserLoginFailed,
UserCreated,
UserDeleted,
UserUpdated,
HostRegistered,
HostRemoved,
GroupCreated,
GroupDeleted,
GroupMembershipChanged,
PatchJobCreated,
PatchJobCancelled,
PatchJobRollback,
MaintenanceWindowCreated,
MaintenanceWindowUpdated,
MaintenanceWindowDeleted,
CertificateIssued,
CertificateRenewed,
CertificateRevoked,
CertificateDownloaded,
ConfigChanged,
DiscoveryScanStarted,
// M11 additions
AuditIntegrityVerified,
EmailNotificationSent,
PatchJobCompleted,
PatchJobFailed,
MaintenanceWindowReminder,
HealthCheckCreated,
HealthCheckUpdated,
HealthCheckDeleted,
CertificateReissued,
}
impl AuditAction {
pub fn as_str(&self) -> &'static str {
match self {
Self::UserLogin => "user_login",
Self::UserLogout => "user_logout",
Self::UserLoginFailed => "user_login_failed",
Self::UserCreated => "user_created",
Self::UserDeleted => "user_deleted",
Self::UserUpdated => "user_updated",
Self::HostRegistered => "host_registered",
Self::HostRemoved => "host_removed",
Self::GroupCreated => "group_created",
Self::GroupDeleted => "group_deleted",
Self::GroupMembershipChanged => "group_membership_changed",
Self::PatchJobCreated => "patch_job_created",
Self::PatchJobCancelled => "patch_job_cancelled",
Self::PatchJobRollback => "patch_job_rollback",
Self::MaintenanceWindowCreated => "maintenance_window_created",
Self::MaintenanceWindowUpdated => "maintenance_window_updated",
Self::MaintenanceWindowDeleted => "maintenance_window_deleted",
Self::CertificateIssued => "certificate_issued",
Self::CertificateRenewed => "certificate_renewed",
Self::CertificateRevoked => "certificate_revoked",
Self::CertificateDownloaded => "certificate_downloaded",
Self::ConfigChanged => "config_changed",
Self::DiscoveryScanStarted => "discovery_scan_started",
Self::AuditIntegrityVerified => "audit_integrity_verified",
Self::EmailNotificationSent => "email_notification_sent",
Self::PatchJobCompleted => "patch_job_completed",
Self::PatchJobFailed => "patch_job_failed",
Self::MaintenanceWindowReminder => "maintenance_window_reminder",
Self::HealthCheckCreated => "health_check_created",
Self::HealthCheckUpdated => "health_check_updated",
Self::HealthCheckDeleted => "health_check_deleted",
Self::CertificateReissued => "certificate_reissued",
}
}
}
/// Write an audit event to the database.
///
/// Computes a hash chain entry using the previous row's hash.
/// Non-fatal: logs errors but does not propagate them to avoid
/// disrupting the primary operation.
#[allow(clippy::too_many_arguments)]
pub async fn log_event(
pool: &PgPool,
action: AuditAction,
actor_user_id: Option<Uuid>,
actor_username: Option<&str>,
target_type: Option<&str>,
target_id: Option<&str>,
details: serde_json::Value,
ip_address: Option<IpAddr>,
request_id: Option<&str>,
) {
let result = write_audit_row(
pool,
action,
actor_user_id,
actor_username,
target_type,
target_id,
details,
ip_address,
request_id,
)
.await;
if let Err(e) = result {
tracing::error!(error = %e, action = action.as_str(), "Failed to write audit log");
}
}
#[allow(clippy::too_many_arguments)]
async fn write_audit_row(
pool: &PgPool,
action: AuditAction,
actor_user_id: Option<Uuid>,
actor_username: Option<&str>,
target_type: Option<&str>,
target_id: Option<&str>,
details: serde_json::Value,
ip_address: Option<IpAddr>,
request_id: Option<&str>,
) -> Result<(), sqlx::Error> {
// Fetch previous hash for chain
let prev_hash: Option<String> =
sqlx::query_scalar("SELECT row_hash FROM audit_log ORDER BY id DESC LIMIT 1")
.fetch_optional(pool)
.await?;
let prev = prev_hash.unwrap_or_default();
let now = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true);
let action_str = action.as_str();
let uid_str = actor_user_id.map(|u| u.to_string()).unwrap_or_default();
let uname = actor_username.unwrap_or("");
let ttype = target_type.unwrap_or("");
let tid = target_id.unwrap_or("");
let details_str = serde_json::to_string(&details).unwrap_or_default();
let ip_str = ip_address.map(|ip| ip.to_string()).unwrap_or_default();
let rid = request_id.unwrap_or("");
// Hash: SHA-256(prev_hash + action + actor_user_id + actor_username +
// target_type + target_id + details_json + ip_address +
// request_id + created_at)
let mut hasher = Sha256::new();
hasher.update(prev.as_bytes());
hasher.update(action_str.as_bytes());
hasher.update(uid_str.as_bytes());
hasher.update(uname.as_bytes());
hasher.update(ttype.as_bytes());
hasher.update(tid.as_bytes());
hasher.update(details_str.as_bytes());
hasher.update(ip_str.as_bytes());
hasher.update(rid.as_bytes());
hasher.update(now.as_bytes());
let row_hash = hex::encode(hasher.finalize());
let ip_for_db = ip_address.map(|ip| ip.to_string());
sqlx::query(
r#"
INSERT INTO audit_log
(action, actor_user_id, actor_username, target_type, target_id,
details, ip_address, request_id, created_at, row_hash, prev_hash)
VALUES
($1::audit_action, $2, $3, $4, $5, $6, $7::inet, $8, $9::timestamptz, $10, $11)
"#,
)
.bind(action_str)
.bind(actor_user_id)
.bind(actor_username)
.bind(target_type)
.bind(target_id)
.bind(details)
.bind(ip_for_db)
.bind(request_id)
.bind(&now)
.bind(&row_hash)
.bind(&prev)
.execute(pool)
.await?;
Ok(())
}
/// Result of an audit integrity verification pass.
#[derive(Debug, serde::Serialize)]
pub struct IntegrityResult {
/// Whether the chain is intact (no tampering detected).
pub intact: bool,
/// Total number of rows checked.
pub rows_checked: i64,
/// List of errors found (row id, expected hash, actual hash).
pub errors: Vec<IntegrityError>,
}
/// A single integrity error detected in the audit chain.
#[derive(Debug, serde::Serialize)]
pub struct IntegrityError {
pub row_id: i64,
pub expected_hash: String,
pub actual_hash: String,
}
/// Row read from audit_log for integrity verification.
#[derive(Debug, sqlx::FromRow)]
struct AuditRow {
id: i64,
action: String,
actor_user_id: Option<uuid::Uuid>,
actor_username: Option<String>,
target_type: Option<String>,
target_id: Option<String>,
details: Option<serde_json::Value>,
ip_address: Option<String>,
request_id: Option<String>,
created_at: Option<chrono::DateTime<chrono::Utc>>,
row_hash: String,
prev_hash: String,
}
/// Walk the audit_log rows ordered by id and verify each row_hash matches
/// the recomputed hash. Returns an [`IntegrityResult`] describing any
/// tampering detected.
pub async fn verify_integrity(pool: &PgPool) -> IntegrityResult {
let rows: Vec<AuditRow> = match sqlx::query_as(
r#"
SELECT id, action::text AS action, actor_user_id, actor_username,
target_type, target_id, details,
host(ip_address) AS ip_address,
request_id, created_at, row_hash, prev_hash
FROM audit_log
ORDER BY id ASC
"#,
)
.fetch_all(pool)
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "verify_integrity: failed to fetch audit rows");
return IntegrityResult {
intact: false,
rows_checked: 0,
errors: vec![],
};
},
};
let mut errors = Vec::new();
let mut expected_prev_hash = String::new();
for row in &rows {
// Verify prev_hash linkage
if row.prev_hash != expected_prev_hash {
errors.push(IntegrityError {
row_id: row.id,
expected_hash: expected_prev_hash.clone(),
actual_hash: row.prev_hash.clone(),
});
}
// Recompute the row hash from all fields
let uid_str = row.actor_user_id.map(|u| u.to_string()).unwrap_or_default();
let uname = row.actor_username.as_deref().unwrap_or("");
let ttype = row.target_type.as_deref().unwrap_or("");
let tid = row.target_id.as_deref().unwrap_or("");
let details_str = row
.details
.as_ref()
.and_then(|v| serde_json::to_string(v).ok())
.unwrap_or_default();
let ip_str = row.ip_address.as_deref().unwrap_or("");
let rid = row.request_id.as_deref().unwrap_or("");
let created_str = row
.created_at
.map(|c| c.to_rfc3339_opts(chrono::SecondsFormat::Micros, true))
.unwrap_or_default();
let mut hasher = Sha256::new();
hasher.update(row.prev_hash.as_bytes());
hasher.update(row.action.as_bytes());
hasher.update(uid_str.as_bytes());
hasher.update(uname.as_bytes());
hasher.update(ttype.as_bytes());
hasher.update(tid.as_bytes());
hasher.update(details_str.as_bytes());
hasher.update(ip_str.as_bytes());
hasher.update(rid.as_bytes());
hasher.update(created_str.as_bytes());
let computed_hash = hex::encode(hasher.finalize());
if row.row_hash != computed_hash {
errors.push(IntegrityError {
row_id: row.id,
expected_hash: computed_hash,
actual_hash: row.row_hash.clone(),
});
}
// Next row should have this row's hash as prev_hash
expected_prev_hash = row.row_hash.clone();
}
let intact = errors.is_empty();
let rows_checked = rows.len() as i64;
IntegrityResult {
intact,
rows_checked,
errors,
}
}

View File

@ -0,0 +1,214 @@
use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Serialize};
/// Rate limiting configuration per route group.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimitConfig {
/// Enrollment endpoint: requests per minute per IP (default: 5)
#[serde(default = "default_enrollment_rpm")]
pub enrollment_rpm: u32,
/// Enrollment burst allowance (default: 3)
#[serde(default = "default_enrollment_burst")]
pub enrollment_burst: u32,
/// Public auth endpoints: requests per minute per IP (default: 20)
#[serde(default = "default_auth_rpm")]
pub auth_rpm: u32,
/// Auth burst allowance (default: 10)
#[serde(default = "default_auth_burst")]
pub auth_burst: u32,
/// Authenticated API: requests per minute per IP (default: 120)
#[serde(default = "default_api_rpm")]
pub api_rpm: u32,
/// API burst allowance (default: 30)
#[serde(default = "default_api_burst")]
pub api_burst: u32,
}
fn default_enrollment_rpm() -> u32 {
5
}
fn default_enrollment_burst() -> u32 {
3
}
fn default_auth_rpm() -> u32 {
20
}
fn default_auth_burst() -> u32 {
10
}
fn default_api_rpm() -> u32 {
120
}
fn default_api_burst() -> u32 {
30
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
enrollment_rpm: default_enrollment_rpm(),
enrollment_burst: default_enrollment_burst(),
auth_rpm: default_auth_rpm(),
auth_burst: default_auth_burst(),
api_rpm: default_api_rpm(),
api_burst: default_api_burst(),
}
}
}
/// Top-level application configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub worker: WorkerConfig,
pub logging: LoggingConfig,
pub security: SecurityConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
/// Bind address for the web server
pub host: String,
/// HTTPS port
pub port: u16,
/// Path to static frontend assets
pub static_dir: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DatabaseConfig {
/// Full PostgreSQL connection URL
pub url: String,
/// Maximum pool connections
pub max_connections: u32,
/// Minimum pool connections
pub min_connections: u32,
/// Connection acquire timeout in seconds
pub acquire_timeout_secs: u64,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerConfig {
/// Health poll interval in seconds (default: 300 = 5 min)
pub health_poll_interval_secs: u64,
/// Patch data poll interval in seconds (default: 1800 = 30 min)
pub patch_poll_interval_secs: u64,
/// Health check poll interval in seconds (default: 300 = 5 min)
#[serde(default = "default_health_check_poll_interval")]
pub health_check_poll_interval_secs: u64,
/// Maximum concurrent agent calls
pub max_concurrent_agent_calls: usize,
/// Worker heartbeat interval in seconds
pub heartbeat_interval_secs: u64,
/// WS relay HTTP polling fallback interval in seconds (default: 10)
pub ws_relay_poll_interval_secs: u64,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoggingConfig {
/// Log level filter: trace, debug, info, warn, error
pub level: String,
/// Output format: json or pretty
pub format: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SecurityConfig {
/// IP whitelist (CIDR or individual IPs); empty = allow all (not recommended)
pub ip_whitelist: Vec<String>,
/// JWT signing key path (Ed25519 PEM)
pub jwt_signing_key_path: String,
/// JWT verification key path (Ed25519 public PEM)
pub jwt_verify_key_path: String,
/// JWT access token TTL in seconds (default: 900 = 15 min)
pub jwt_access_ttl_secs: u64,
/// Agent mTLS client cert path
pub agent_client_cert_path: String,
/// Agent mTLS client key path
pub agent_client_key_path: String,
/// Internal CA cert path
pub ca_cert_path: String,
/// Internal CA key path
pub ca_key_path: String,
/// Web UI TLS cert path
pub web_tls_cert_path: String,
/// Web UI TLS key path
pub web_tls_key_path: String,
/// Frontend URL to redirect to after SSO callback (default: http://localhost:5173/auth/sso/callback)
#[serde(default = "default_sso_callback_url")]
pub sso_callback_url: String,
}
impl AppConfig {
/// Load configuration from a TOML file and environment variable overrides.
///
/// Environment variables follow the pattern: `PATCH_MANAGER__SECTION__KEY`
/// e.g. `PATCH_MANAGER__DATABASE__URL=postgres://...`
pub fn load(config_path: &str) -> Result<Self, ConfigError> {
let cfg = Config::builder()
.add_source(File::with_name(config_path).required(false))
.add_source(
Environment::with_prefix("PATCH_MANAGER")
.separator("__")
.try_parsing(true),
)
.build()?;
cfg.try_deserialize()
}
}
fn default_health_check_poll_interval() -> u64 {
300
}
fn default_sso_callback_url() -> String {
"http://localhost:5173/auth/sso/callback".to_string()
}
impl Default for AppConfig {
fn default() -> Self {
Self {
server: ServerConfig {
host: "0.0.0.0".to_string(),
port: 443,
static_dir: "/usr/share/patch-manager/frontend".to_string(),
},
database: DatabaseConfig {
url: "postgres://patch_manager:changeme@localhost/patch_manager".to_string(),
max_connections: 20,
min_connections: 2,
acquire_timeout_secs: 30,
},
worker: WorkerConfig {
health_poll_interval_secs: 300,
patch_poll_interval_secs: 1800,
health_check_poll_interval_secs: 300,
max_concurrent_agent_calls: 64,
heartbeat_interval_secs: 30,
ws_relay_poll_interval_secs: 10,
},
logging: LoggingConfig {
level: "info".to_string(),
format: "json".to_string(),
},
security: SecurityConfig {
ip_whitelist: vec![],
jwt_signing_key_path: "/etc/patch-manager/jwt/signing.pem".to_string(),
jwt_verify_key_path: "/etc/patch-manager/jwt/verify.pem".to_string(),
jwt_access_ttl_secs: 900,
agent_client_cert_path: "/etc/patch-manager/certs/client.crt".to_string(),
agent_client_key_path: "/etc/patch-manager/certs/client.key".to_string(),
ca_cert_path: "/etc/patch-manager/ca/ca.crt".to_string(),
ca_key_path: "/etc/patch-manager/ca/ca.key".to_string(),
web_tls_cert_path: "/etc/patch-manager/tls/web.crt".to_string(),
web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(),
sso_callback_url: default_sso_callback_url(),
},
rate_limit: RateLimitConfig::default(),
}
}
}

80
crates/pm-core/src/crypto.rs Executable file
View File

@ -0,0 +1,80 @@
//! AES-256-GCM encryption for sensitive health check credentials.
//!
//! Uses a per-install key stored at `/etc/patch-manager/keys/health-check.key`.
use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Nonce,
};
use rand::RngCore;
use std::fs;
use std::path::Path;
pub const KEY_PATH: &str = "/etc/patch-manager/keys/health-check.key";
/// Load or create the per-install encryption key.
/// If the key file doesn't exist, generates a new 256-bit key and saves it.
pub fn load_or_create_key(path: &Path) -> Result<[u8; 32], CryptoError> {
if path.exists() {
let key_bytes = fs::read(path).map_err(CryptoError::Io)?;
if key_bytes.len() != 32 {
return Err(CryptoError::InvalidKeyLength(key_bytes.len()));
}
let mut key = [0u8; 32];
key.copy_from_slice(&key_bytes);
Ok(key)
} else {
let mut key = [0u8; 32];
OsRng.fill_bytes(&mut key);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(CryptoError::Io)?;
}
fs::write(path, key).map_err(CryptoError::Io)?;
// Set permissions to 0600 (owner read/write only)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
fs::set_permissions(path, fs::Permissions::from_mode(0o600))
.map_err(CryptoError::Io)?;
}
Ok(key)
}
}
/// Encrypt plaintext with AES-256-GCM. Returns (ciphertext, nonce).
pub fn encrypt(plaintext: &str, key: &[u8; 32]) -> Result<(Vec<u8>, Vec<u8>), CryptoError> {
let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::KeyInit(e.to_string()))?;
let mut nonce_bytes = [0u8; 12];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext.as_bytes())
.map_err(|_| CryptoError::EncryptionFailed)?;
Ok((ciphertext, nonce_bytes.to_vec()))
}
/// Decrypt AES-256-GCM ciphertext with the given nonce.
pub fn decrypt(ciphertext: &[u8], nonce: &[u8], key: &[u8; 32]) -> Result<String, CryptoError> {
let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::KeyInit(e.to_string()))?;
let nonce = Nonce::from_slice(nonce);
let plaintext = cipher
.decrypt(nonce, ciphertext)
.map_err(|_| CryptoError::DecryptionFailed)?;
String::from_utf8(plaintext).map_err(CryptoError::Utf8)
}
#[derive(Debug, thiserror::Error)]
pub enum CryptoError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Invalid key length: expected 32 bytes, got {0}")]
InvalidKeyLength(usize),
#[error("Key init error: {0}")]
KeyInit(String),
#[error("Encryption failed")]
EncryptionFailed,
#[error("Decryption failed")]
DecryptionFailed,
#[error("UTF-8 error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
}

117
crates/pm-core/src/db.rs Executable file
View File

@ -0,0 +1,117 @@
use crate::config::DatabaseConfig;
use crate::models::{CreateEnrollmentRequest, EnrollmentRequest};
use sqlx::postgres::{PgPool, PgPoolOptions};
use std::time::Duration;
use uuid::Uuid;
/// Initialize and return a PostgreSQL connection pool.
pub async fn init_pool(cfg: &DatabaseConfig) -> Result<PgPool, sqlx::Error> {
let pool = PgPoolOptions::new()
.max_connections(cfg.max_connections)
.min_connections(cfg.min_connections)
.acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
.connect(&cfg.url)
.await?;
tracing::info!(
max_connections = cfg.max_connections,
"PostgreSQL connection pool initialized"
);
Ok(pool)
}
/// Run embedded SQLx migrations.
/// Uses a PostgreSQL advisory lock to ensure only one writer runs migrations.
pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::migrate::MigrateError> {
tracing::info!("Acquiring advisory lock for migrations");
// Advisory lock key — consistent hash of the application name
const LOCK_KEY: i64 = 0x7061_7463_686d_6772; // "patchmgr" bytes
// Acquire advisory lock; blocks until granted
sqlx::query("SELECT pg_advisory_lock($1)")
.bind(LOCK_KEY)
.execute(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to acquire advisory lock");
e
})
.expect("Advisory lock must be acquired before running migrations");
tracing::info!("Running database migrations");
let result = sqlx::migrate!("../../migrations").run(pool).await;
// Always release the lock
sqlx::query("SELECT pg_advisory_unlock($1)")
.bind(LOCK_KEY)
.execute(pool)
.await
.ok();
match &result {
Ok(_) => tracing::info!("Database migrations completed successfully"),
Err(e) => tracing::error!(error = %e, "Database migrations failed"),
}
result
}
// ============================================================
// Enrollment Requests
// ============================================================
pub async fn create_enrollment_request(
pool: &PgPool,
req: CreateEnrollmentRequest,
token_hash: String,
) -> Result<EnrollmentRequest, sqlx::Error> {
sqlx::query_as::<
_,
EnrollmentRequest,
>(
r#"
INSERT INTO enrollment_requests (machine_id, fqdn, ip_address, os_details, polling_token, hostname)
VALUES ($1, $2, $3::inet, $4, $5, $6)
RETURNING id, machine_id, fqdn, ip_address::text, os_details, polling_token, hostname, created_at, expires_at
"#,
)
.bind(req.machine_id)
.bind(req.fqdn)
.bind(req.ip_address)
.bind(req.os_details)
.bind(token_hash)
.bind(&req.hostname)
.fetch_one(pool)
.await
}
pub async fn list_enrollment_requests(
pool: &PgPool,
) -> Result<Vec<EnrollmentRequest>, sqlx::Error> {
sqlx::query_as::<_, EnrollmentRequest>(
"SELECT id, machine_id, fqdn, ip_address::text, os_details, polling_token, hostname, created_at, expires_at FROM enrollment_requests ORDER BY created_at DESC",
)
.fetch_all(pool)
.await
}
pub async fn delete_enrollment_request(pool: &PgPool, id: Uuid) -> Result<u64, sqlx::Error> {
let result = sqlx::query("DELETE FROM enrollment_requests WHERE id = $1")
.bind(id)
.execute(pool)
.await?;
Ok(result.rows_affected())
}
/// Check that the database schema is at the expected version.
/// Used by the worker to wait until migrations have been applied.
pub async fn check_schema_version(pool: &PgPool) -> Result<i64, sqlx::Error> {
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_migrations WHERE success = true")
.fetch_one(pool)
.await?;
Ok(row.0)
}

126
crates/pm-core/src/error.rs Executable file
View File

@ -0,0 +1,126 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Unified application error type.
#[derive(Debug, Error)]
pub enum AppError {
#[error("Not found: {0}")]
NotFound(String),
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Conflict: {0}")]
Conflict(String),
#[error("Unprocessable entity: {0}")]
UnprocessableEntity(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Internal error: {0}")]
Internal(#[from] anyhow::Error),
#[error("Configuration error: {0}")]
Config(String),
}
/// JSON error envelope returned to clients.
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorDetail {
/// Machine-readable error code (e.g. "not_found", "unauthorized")
pub code: String,
/// Human-readable message
pub message: String,
/// Request ID for correlation (set by middleware)
pub request_id: Option<String>,
/// Optional structured details
pub details: Option<serde_json::Value>,
}
impl ErrorResponse {
pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
Self {
error: ErrorDetail {
code: code.into(),
message: message.into(),
request_id: None,
details: None,
},
}
}
pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
self.error.request_id = Some(request_id.into());
self
}
pub fn with_details(mut self, details: serde_json::Value) -> Self {
self.error.details = Some(details);
self
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, code, message) = match &self {
AppError::NotFound(msg) => (StatusCode::NOT_FOUND, "not_found", msg.clone()),
AppError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, "unauthorized", msg.clone()),
AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, "forbidden", msg.clone()),
AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, "bad_request", msg.clone()),
AppError::Conflict(msg) => (StatusCode::CONFLICT, "conflict", msg.clone()),
AppError::UnprocessableEntity(msg) => (
StatusCode::UNPROCESSABLE_ENTITY,
"unprocessable_entity",
msg.clone(),
),
AppError::Database(e) => {
tracing::error!(error = %e, "Database error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"An internal error occurred".to_string(),
)
},
AppError::Internal(e) => {
tracing::error!(error = %e, "Internal error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"An internal error occurred".to_string(),
)
},
AppError::Config(msg) => {
tracing::error!(error = %msg, "Configuration error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"config_error",
"Server configuration error".to_string(),
)
},
};
let body = ErrorResponse::new(code, message);
(status, Json(body)).into_response()
}
}
/// Convenience alias for handler return types.
pub type ApiResult<T> = Result<T, AppError>;

23
crates/pm-core/src/lib.rs Executable file
View File

@ -0,0 +1,23 @@
pub mod audit;
pub mod config;
pub mod crypto;
pub mod db;
pub mod error;
pub mod logging;
pub mod models;
pub mod request_id;
// Re-export commonly used types
pub use config::AppConfig;
pub use crypto::{decrypt, encrypt, load_or_create_key, CryptoError, KEY_PATH};
pub use error::{AppError, ErrorResponse};
pub use models::{
AdminResetPasswordRequest, AuthProvider, ChangePasswordRequest, CreateGroupRequest,
CreateHealthCheckRequest, CreateHostRequest, CreateUserRequest, DiscoveryCidrRequest,
DiscoveryResult, Group, HealthCheck, HealthCheckResult, HealthCheckWithResult, Host,
HostHealthStatus, HostSummary, RegisterDiscoveredRequest, UpdateGroupRequest,
UpdateHealthCheckRequest, UpdateUserRequest, User, UserRole as DbUserRole,
};
// Re-export audit integrity types
pub use audit::{verify_integrity, IntegrityError, IntegrityResult};

31
crates/pm-core/src/logging.rs Executable file
View File

@ -0,0 +1,31 @@
use crate::config::LoggingConfig;
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
/// Initialize the global tracing subscriber.
///
/// Format is controlled by `cfg.format`:
/// - `"json"` — machine-readable JSON (production default)
/// - anything else — human-readable pretty output (development)
///
/// Log level is controlled by `cfg.level` (e.g. `"info"`, `"debug"`).
/// The `RUST_LOG` environment variable overrides `cfg.level`.
pub fn init(cfg: &LoggingConfig) {
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cfg.level));
match cfg.format.as_str() {
"json" => {
tracing_subscriber::registry()
.with(filter)
.with(fmt::layer().json().with_current_span(true))
.init();
},
_ => {
tracing_subscriber::registry()
.with(filter)
.with(fmt::layer().pretty())
.init();
},
}
tracing::info!(format = %cfg.format, level = %cfg.level, "Logging initialized");
}

555
crates/pm-core/src/models.rs Executable file
View File

@ -0,0 +1,555 @@
//! Shared database model types used across pm-web and pm-worker.
//!
//! These match the database schema defined in migrations/.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
// ============================================================
// Enumerations (matching PostgreSQL ENUM types)
// ============================================================
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "lowercase")]
#[sqlx(type_name = "host_health_status", rename_all = "lowercase")]
pub enum HostHealthStatus {
Pending,
Healthy,
Degraded,
Unreachable,
}
impl std::fmt::Display for HostHealthStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Pending => write!(f, "pending"),
Self::Healthy => write!(f, "healthy"),
Self::Degraded => write!(f, "degraded"),
Self::Unreachable => write!(f, "unreachable"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "lowercase")]
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
pub enum UserRole {
Admin,
Operator,
Reporter,
}
impl std::fmt::Display for UserRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Admin => write!(f, "admin"),
Self::Operator => write!(f, "operator"),
Self::Reporter => write!(f, "reporter"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "auth_provider", rename_all = "snake_case")]
pub enum AuthProvider {
Local,
#[sqlx(rename = "azure_sso")]
AzureSso,
Keycloak,
Oidc,
}
impl std::fmt::Display for AuthProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Local => write!(f, "local"),
Self::AzureSso => write!(f, "azure_sso"),
Self::Keycloak => write!(f, "keycloak"),
Self::Oidc => write!(f, "oidc"),
}
}
}
// ============================================================
// Host
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Host {
pub id: Uuid,
pub fqdn: String,
pub ip_address: String, // stored as INET, returned as text
pub display_name: String,
pub os_family: Option<String>,
pub os_name: Option<String>,
pub arch: Option<String>,
pub agent_version: Option<String>,
pub health_status: HostHealthStatus,
pub last_health_at: Option<DateTime<Utc>>,
pub last_patch_at: Option<DateTime<Utc>>,
pub agent_port: i32,
pub notes: String,
pub registered_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Payload for registering a new host.
#[derive(Debug, Deserialize)]
pub struct CreateHostRequest {
/// FQDN or IP address of the managed host
pub fqdn: String,
pub display_name: Option<String>,
pub agent_port: Option<i32>,
pub notes: Option<String>,
pub group_ids: Option<Vec<Uuid>>,
}
/// Payload for updating an existing host.
#[derive(Debug, Deserialize)]
pub struct UpdateHostRequest {
pub fqdn: Option<String>,
pub ip_address: Option<String>,
pub display_name: Option<String>,
}
/// Host list item (lighter projection for list views)
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct HostSummary {
pub id: Uuid,
pub fqdn: String,
pub ip_address: String,
pub display_name: String,
pub os_family: Option<String>,
pub os_name: Option<String>,
pub health_status: HostHealthStatus,
pub agent_version: Option<String>,
pub patches_missing: i32,
pub health_check_status: Option<String>,
pub registered_at: DateTime<Utc>,
}
// ============================================================
// Host Enrollment
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct EnrollmentRequest {
pub id: Uuid,
pub machine_id: String,
pub fqdn: String,
pub ip_address: String,
pub os_details: serde_json::Value,
pub polling_token: String,
/// Short hostname provided during enrollment (optional).
pub hostname: Option<String>,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
/// Payload for initial host enrollment request.
#[derive(Debug, Deserialize, Serialize)]
pub struct CreateEnrollmentRequest {
pub machine_id: String,
pub fqdn: String,
pub ip_address: String,
pub os_details: serde_json::Value,
/// Short hostname (from /etc/hostname, optional).
pub hostname: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "lowercase")]
pub enum EnrollmentStatusResponse {
Pending,
Approved {
ca_crt: String,
server_crt: String,
server_key: String,
},
Denied,
NotFound,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PkiBundle {
pub ca_crt: String,
pub server_crt: String,
pub server_key: String,
}
// ============================================================
// Health Checks
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct HealthCheck {
pub id: Uuid,
pub host_id: Uuid,
pub name: String,
pub check_type: String, // "service" or "http"
pub enabled: bool,
// Service check fields
pub service_name: Option<String>,
// HTTP check fields
pub url: Option<String>,
pub expected_body: Option<String>,
pub ignore_cert_errors: bool,
pub basic_auth_user: Option<String>,
pub target_host_id: Option<Uuid>,
// basic_auth_pass_encrypted and nonce NOT exposed in API responses
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckWithResult {
#[serde(flatten)]
pub check: HealthCheck,
pub last_result: Option<HealthCheckResult>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct HealthCheckResult {
pub id: Uuid,
pub check_id: Uuid,
pub healthy: bool,
pub detail: Option<String>,
pub latency_ms: Option<i32>,
pub checked_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateHealthCheckRequest {
pub name: String,
pub check_type: String, // "service" or "http"
pub service_name: Option<String>,
pub url: Option<String>,
pub expected_body: Option<String>,
#[serde(default = "default_true")]
pub ignore_cert_errors: bool,
pub basic_auth_user: Option<String>,
pub basic_auth_pass: Option<String>, // plaintext in request, encrypted before storage
pub target_host_id: Option<Uuid>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateHealthCheckRequest {
pub name: Option<String>,
pub enabled: Option<bool>,
pub service_name: Option<String>,
pub url: Option<String>,
pub expected_body: Option<String>,
pub ignore_cert_errors: Option<bool>,
pub basic_auth_user: Option<String>,
pub basic_auth_pass: Option<String>, // if provided, re-encrypt
pub target_host_id: Option<Uuid>,
}
fn default_true() -> bool {
true
}
// ============================================================
// Group
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Group {
pub id: Uuid,
pub name: String,
pub description: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Deserialize)]
pub struct CreateGroupRequest {
pub name: String,
pub description: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateGroupRequest {
pub name: Option<String>,
pub description: Option<String>,
}
// ============================================================
// User
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct User {
pub id: Uuid,
pub username: String,
pub display_name: String,
pub email: String,
pub role: UserRole,
pub auth_provider: AuthProvider,
pub mfa_enabled: bool,
pub is_active: bool,
pub force_password_reset: bool,
pub last_login_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// User create payload (admin-only)
#[derive(Debug, Deserialize)]
pub struct CreateUserRequest {
pub username: String,
pub display_name: Option<String>,
pub email: String,
pub role: String,
pub password: String,
}
/// User update payload
#[derive(Debug, Deserialize)]
pub struct UpdateUserRequest {
pub display_name: Option<String>,
pub email: Option<String>,
pub role: Option<String>,
pub is_active: Option<bool>,
pub force_password_reset: Option<bool>,
}
/// Self-service password change payload
#[derive(Debug, Deserialize)]
pub struct ChangePasswordRequest {
pub current_password: String,
pub new_password: String,
}
/// Admin password reset payload
#[derive(Debug, Deserialize)]
pub struct AdminResetPasswordRequest {
pub new_password: String,
#[serde(default)]
pub force_password_reset: bool,
}
// ============================================================
// Discovery
// ============================================================
/// Request body for CIDR auto-discovery scan.
#[derive(Debug, Deserialize)]
pub struct DiscoveryCidrRequest {
/// CIDR range to scan (e.g. "10.0.0.0/24")
pub cidr: String,
/// Agent port to probe (default 12443)
pub agent_port: Option<i32>,
}
/// A single discovered host result.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct DiscoveryResult {
pub id: Uuid,
pub scan_id: Uuid,
pub ip_address: String,
pub fqdn: Option<String>,
pub agent_version: Option<String>,
pub os_name: Option<String>,
pub agent_port: i32,
pub discovered_at: DateTime<Utc>,
pub registered: bool,
}
/// Payload for registering a host from a discovery result.
#[derive(Debug, Deserialize)]
pub struct RegisterDiscoveredRequest {
pub discovery_id: Uuid,
pub display_name: Option<String>,
pub group_ids: Option<Vec<Uuid>>,
}
// ============================================================
// Patch Jobs
// ============================================================
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "lowercase")]
#[sqlx(type_name = "job_status", rename_all = "lowercase")]
pub enum JobStatus {
Queued,
Pending,
Running,
Succeeded,
Failed,
Cancelled,
}
impl std::fmt::Display for JobStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Queued => write!(f, "queued"),
Self::Pending => write!(f, "pending"),
Self::Running => write!(f, "running"),
Self::Succeeded => write!(f, "succeeded"),
Self::Failed => write!(f, "failed"),
Self::Cancelled => write!(f, "cancelled"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "job_kind", rename_all = "snake_case")]
pub enum JobKind {
#[sqlx(rename = "patch_apply")]
PatchApply,
#[sqlx(rename = "patch_remove")]
PatchRemove,
Reboot,
Rollback,
}
/// Full `patch_jobs` row.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct PatchJob {
pub id: Uuid,
pub kind: JobKind,
pub status: JobStatus,
pub created_by_user_id: Option<Uuid>,
pub parent_job_id: Option<Uuid>,
pub maintenance_window_id: Option<Uuid>,
pub immediate: bool,
pub patch_selection: serde_json::Value,
pub notes: String,
pub created_at: DateTime<Utc>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
}
/// Full `patch_job_hosts` row (includes columns added in migration 003).
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct PatchJobHost {
pub id: Uuid,
pub job_id: Uuid,
pub host_id: Uuid,
pub status: JobStatus,
pub agent_job_id: Option<String>,
pub retry_count: i32,
pub output: String,
pub error_message: Option<String>,
pub retry_next_at: Option<DateTime<Utc>>,
pub last_error: Option<String>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
}
/// Request payload for creating a patch job via `POST /api/v1/jobs`.
#[derive(Debug, Deserialize)]
pub struct CreateJobRequest {
/// Host IDs to patch.
pub host_ids: Vec<Uuid>,
/// Package names to apply (empty = all available patches).
pub packages: Vec<String>,
/// If true: apply immediately. If false: queue for next maintenance window.
pub immediate: bool,
/// Optional maintenance window to bind to.
pub maintenance_window_id: Option<Uuid>,
/// Allow reboot if required by patches.
pub allow_reboot: Option<bool>,
/// Optional operator notes.
pub notes: Option<String>,
}
/// Summary row for job list view (aggregates per-host counts).
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct PatchJobSummary {
pub id: Uuid,
pub kind: JobKind,
pub status: JobStatus,
pub immediate: bool,
pub host_count: i64,
pub succeeded_count: i64,
pub failed_count: i64,
pub notes: String,
pub created_at: DateTime<Utc>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
}
// ============================================================
// Maintenance Windows
// ============================================================
/// Recurrence type for a maintenance window.
/// Mirrors the `window_recurrence` PostgreSQL ENUM.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "lowercase")]
#[sqlx(type_name = "window_recurrence", rename_all = "lowercase")]
pub enum WindowRecurrence {
/// Single one-time window (at `start_at` for `duration_minutes` minutes).
Once,
/// Repeats every day at the time portion of `start_at`.
Daily,
/// Repeats on the day-of-week in `recurrence_day` (0 = Sunday).
Weekly,
/// Repeats on the day-of-month in `recurrence_day` (1-31).
Monthly,
}
impl std::fmt::Display for WindowRecurrence {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Once => write!(f, "once"),
Self::Daily => write!(f, "daily"),
Self::Weekly => write!(f, "weekly"),
Self::Monthly => write!(f, "monthly"),
}
}
}
/// Full row from `maintenance_windows`.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct MaintenanceWindow {
pub id: Uuid,
pub host_id: Uuid,
pub label: String,
pub recurrence: WindowRecurrence,
/// Absolute start time (one-time) or time-of-day reference (recurring).
pub start_at: DateTime<Utc>,
/// Duration of the window in minutes.
pub duration_minutes: i32,
/// Day-of-week (0=Sun, weekly) or day-of-month (1-31, monthly); NULL for once/daily.
pub recurrence_day: Option<i32>,
pub enabled: bool,
pub auto_apply: bool,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Payload for `POST /api/v1/hosts/{id}/maintenance-windows`.
#[derive(Debug, Deserialize)]
pub struct CreateMaintenanceWindowRequest {
pub label: String,
pub recurrence: WindowRecurrence,
/// RFC 3339 / ISO 8601 timestamp (UTC recommended).
pub start_at: DateTime<Utc>,
/// How many minutes the window is open (default 60).
pub duration_minutes: Option<i32>,
/// Required for `weekly` (0-6) and `monthly` (1-31).
pub recurrence_day: Option<i32>,
/// Whether the window is active (default true).
pub enabled: Option<bool>,
/// Whether to auto-create a patch_apply job when this window opens and patches are pending (default true).
pub auto_apply: Option<bool>,
}
/// Payload for `PUT /api/v1/hosts/{id}/maintenance-windows/{window_id}`.
#[derive(Debug, Deserialize)]
pub struct UpdateMaintenanceWindowRequest {
pub label: Option<String>,
pub recurrence: Option<WindowRecurrence>,
pub start_at: Option<DateTime<Utc>>,
pub duration_minutes: Option<i32>,
pub recurrence_day: Option<i32>,
pub enabled: Option<bool>,
pub auto_apply: Option<bool>,
}

View File

@ -0,0 +1,39 @@
use axum::{extract::Request, http::HeaderValue, middleware::Next, response::Response};
use ulid::Ulid;
/// HTTP header name for request correlation IDs.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Axum middleware that generates a ULID request ID and attaches it
/// to both the request extensions and the response header.
pub async fn request_id_middleware(mut req: Request, next: Next) -> Response {
// Use existing X-Request-Id if provided by upstream proxy, else generate
let id = req
.headers()
.get(REQUEST_ID_HEADER)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| Ulid::new().to_string());
// Insert as extension so handlers can access it
req.extensions_mut().insert(RequestId(id.clone()));
let mut response = next.run(req).await;
// Echo the ID back in the response
if let Ok(value) = HeaderValue::from_str(&id) {
response.headers_mut().insert(REQUEST_ID_HEADER, value);
}
response
}
/// Extractor type for retrieving the request ID inside handlers.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
impl std::fmt::Display for RequestId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

View File

@ -0,0 +1,24 @@
[package]
name = "pm-reports"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[dependencies]
pm-core = { path = "../pm-core" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
chrono = { workspace = true }
sqlx = { workspace = true }
uuid = { workspace = true }
# Report generation
csv = "1"
printpdf = { version = "0.7", features = ["embedded_images"] }
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder", "line_series", "area_series", "ttf"] }
image = { version = "0.25", default-features = false, features = ["png"] }

351
crates/pm-reports/src/csv.rs Executable file
View File

@ -0,0 +1,351 @@
//! CSV report generation for pm-reports.
use crate::{ReportParams, ReportType};
use anyhow::Context;
/// Generate a CSV report and return the raw bytes.
pub async fn generate_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
match params.report_type {
ReportType::Compliance => compliance_csv(pool, params).await,
ReportType::PatchHistory => patch_history_csv(pool, params).await,
ReportType::Vulnerability => vulnerability_csv(pool, params).await,
ReportType::Audit => audit_csv(pool, params).await,
}
}
// ---------------------------------------------------------------------------
// Compliance
// ---------------------------------------------------------------------------
async fn compliance_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
let rows = if let Some(gid) = params.group_id {
sqlx::query(
"
SELECT
h.id::text AS host_id,
h.display_name,
h.fqdn,
h.health_status::text AS health_status,
h.last_patch_at,
COALESCE(jsonb_array_length(pd.installed_packages), 0) AS total_packages,
COALESCE(pd.patch_count, 0) AS pending_patches,
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages), 0) = 0 THEN 100.0
ELSE ROUND(CAST((1.0 - pd.patch_count::float / NULLIF(jsonb_array_length(pd.installed_packages), 0)) * 100 AS numeric), 1)
END)::float8 AS compliance_pct,
COALESCE(string_agg(DISTINCT g.name, ', '), '') AS group_names
FROM hosts h
LEFT JOIN host_patch_data pd ON pd.host_id = h.id
LEFT JOIN host_groups hg ON hg.host_id = h.id
LEFT JOIN groups g ON g.id = hg.group_id
WHERE h.id IN (
SELECT host_id FROM host_groups WHERE group_id = $1
)
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
ORDER BY compliance_pct ASC
",
)
.bind(gid)
.fetch_all(pool)
.await
.context("compliance query (group filter) failed")?
} else {
sqlx::query(
"
SELECT
h.id::text AS host_id,
h.display_name,
h.fqdn,
h.health_status::text AS health_status,
h.last_patch_at,
COALESCE(jsonb_array_length(pd.installed_packages), 0) AS total_packages,
COALESCE(pd.patch_count, 0) AS pending_patches,
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages), 0) = 0 THEN 100.0
ELSE ROUND(CAST((1.0 - pd.patch_count::float / NULLIF(jsonb_array_length(pd.installed_packages), 0)) * 100 AS numeric), 1)
END)::float8 AS compliance_pct,
COALESCE(string_agg(DISTINCT g.name, ', '), '') AS group_names
FROM hosts h
LEFT JOIN host_patch_data pd ON pd.host_id = h.id
LEFT JOIN host_groups hg ON hg.host_id = h.id
LEFT JOIN groups g ON g.id = hg.group_id
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
ORDER BY compliance_pct ASC
",
)
.fetch_all(pool)
.await
.context("compliance query failed")?
};
let mut wtr = csv::Writer::from_writer(vec![]);
wtr.write_record([
"host_id",
"display_name",
"fqdn",
"group_names",
"total_packages",
"pending_patches",
"compliance_pct",
"last_patch_at",
"health_status",
])?;
for row in &rows {
use sqlx::Row;
let host_id: String = row.try_get("host_id").unwrap_or_default();
let display_name: String = row.try_get("display_name").unwrap_or_default();
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
let group_names: String = row.try_get("group_names").unwrap_or_default();
let total_packages: i64 = row.try_get("total_packages").unwrap_or(0);
let pending_patches: i64 = row.try_get("pending_patches").unwrap_or(0);
let compliance_pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0);
let last_patch_at: Option<chrono::DateTime<chrono::Utc>> =
row.try_get("last_patch_at").unwrap_or(None);
let health_status: String = row.try_get("health_status").unwrap_or_default();
wtr.write_record(&[
host_id,
display_name,
fqdn,
group_names,
total_packages.to_string(),
pending_patches.to_string(),
format!("{:.1}", compliance_pct),
last_patch_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
health_status,
])?;
}
wtr.into_inner().context("csv flush failed")
}
// ---------------------------------------------------------------------------
// Patch history
// ---------------------------------------------------------------------------
async fn patch_history_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
let rows = sqlx::query(
"
SELECT
pj.id::text AS job_id,
pj.kind::text AS job_kind,
pj.status::text AS job_status,
h.display_name,
h.fqdn,
jsonb_array_length(COALESCE(pj.patch_selection->'packages', '[]'::jsonb)) AS package_count,
pjh.started_at,
pjh.completed_at,
EXTRACT(EPOCH FROM (pjh.completed_at - pjh.started_at))::bigint AS duration_seconds,
COALESCE(u.username, 'system') AS operator
FROM patch_job_hosts pjh
JOIN patch_jobs pj ON pj.id = pjh.job_id
JOIN hosts h ON h.id = pjh.host_id
LEFT JOIN users u ON u.id = pj.created_by_user_id
WHERE ($1::timestamptz IS NULL OR pjh.started_at >= $1)
AND ($2::timestamptz IS NULL OR pjh.started_at <= $2)
ORDER BY pjh.started_at DESC
",
)
.bind(params.from)
.bind(params.to)
.fetch_all(pool)
.await
.context("patch history query failed")?;
let mut wtr = csv::Writer::from_writer(vec![]);
wtr.write_record([
"job_id",
"job_kind",
"job_status",
"host_display_name",
"host_fqdn",
"package_count",
"started_at",
"completed_at",
"duration_seconds",
"operator",
])?;
for row in &rows {
use sqlx::Row;
let job_id: String = row.try_get("job_id").unwrap_or_default();
let job_kind: String = row.try_get("job_kind").unwrap_or_default();
let job_status: String = row.try_get("job_status").unwrap_or_default();
let display_name: String = row.try_get("display_name").unwrap_or_default();
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
let package_count: i64 = row.try_get("package_count").unwrap_or(0);
let started_at: Option<chrono::DateTime<chrono::Utc>> =
row.try_get("started_at").unwrap_or(None);
let completed_at: Option<chrono::DateTime<chrono::Utc>> =
row.try_get("completed_at").unwrap_or(None);
let duration_seconds: Option<i64> = row.try_get("duration_seconds").unwrap_or(None);
let operator: String = row.try_get("operator").unwrap_or_default();
wtr.write_record(&[
job_id,
job_kind,
job_status,
display_name,
fqdn,
package_count.to_string(),
started_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
completed_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
duration_seconds.unwrap_or(0).to_string(),
operator,
])?;
}
wtr.into_inner().context("csv flush failed")
}
// ---------------------------------------------------------------------------
// Vulnerability
// ---------------------------------------------------------------------------
async fn vulnerability_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
let mut wtr = csv::Writer::from_writer(vec![]);
wtr.write_record([
"host_id",
"display_name",
"fqdn",
"cve_id",
"package_name",
"severity",
"available_version",
"last_seen_at",
])?;
let result = sqlx::query(
"
SELECT
h.id::text AS host_id,
h.display_name,
h.fqdn,
cve_id,
patch->>'name' AS package_name,
patch->>'severity' AS severity,
patch->>'available_version' AS available_version,
pd.polled_at AS last_seen_at
FROM hosts h
JOIN host_patch_data pd ON pd.host_id = h.id
CROSS JOIN LATERAL jsonb_array_elements(COALESCE(pd.available_patches, '[]'::jsonb)) AS patch
CROSS JOIN LATERAL jsonb_array_elements_text(COALESCE(patch->'cve_ids', '[]'::jsonb)) AS cve_id
WHERE ($1::timestamptz IS NULL OR pd.polled_at >= $1)
AND ($2::timestamptz IS NULL OR pd.polled_at <= $2)
ORDER BY
CASE patch->>'severity'
WHEN 'critical' THEN 1
WHEN 'high' THEN 2
WHEN 'medium' THEN 3
ELSE 4
END,
h.display_name
",
)
.bind(params.from)
.bind(params.to)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
for row in &rows {
use sqlx::Row;
let host_id: String = row.try_get("host_id").unwrap_or_default();
let display_name: String = row.try_get("display_name").unwrap_or_default();
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
let cve_id: String = row.try_get("cve_id").unwrap_or_default();
let package_name: String = row.try_get("package_name").unwrap_or_default();
let severity: String = row.try_get("severity").unwrap_or_default();
let available_version: String =
row.try_get("available_version").unwrap_or_default();
let last_seen_at: Option<chrono::DateTime<chrono::Utc>> =
row.try_get("last_seen_at").unwrap_or(None);
wtr.write_record(&[
host_id,
display_name,
fqdn,
cve_id,
package_name,
severity,
available_version,
last_seen_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
])?;
}
},
Err(e) => {
tracing::warn!(error = %e, "vulnerability query failed — returning header-only CSV");
// Return header-only CSV (no invalid comment rows)
},
}
wtr.into_inner().context("csv flush failed")
}
// ---------------------------------------------------------------------------
// Audit
// ---------------------------------------------------------------------------
async fn audit_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
let rows = sqlx::query(
"
SELECT
id::text AS id,
created_at,
action::text AS action,
actor_username,
target_type,
target_id,
ip_address::text AS ip_address,
request_id
FROM audit_log
WHERE ($1::timestamptz IS NULL OR created_at >= $1)
AND ($2::timestamptz IS NULL OR created_at <= $2)
ORDER BY created_at DESC
LIMIT 10000
",
)
.bind(params.from)
.bind(params.to)
.fetch_all(pool)
.await
.context("audit query failed")?;
let mut wtr = csv::Writer::from_writer(vec![]);
wtr.write_record([
"id",
"created_at",
"action",
"actor_username",
"target_type",
"target_id",
"ip_address",
"request_id",
])?;
for row in &rows {
use sqlx::Row;
let id: String = row.try_get("id").unwrap_or_default();
let created_at: Option<chrono::DateTime<chrono::Utc>> =
row.try_get("created_at").unwrap_or(None);
let action: String = row.try_get("action").unwrap_or_default();
let actor_username: String = row.try_get("actor_username").unwrap_or_default();
let target_type: String = row.try_get("target_type").unwrap_or_default();
let target_id: String = row.try_get("target_id").unwrap_or_default();
let ip_address: String = row.try_get("ip_address").unwrap_or_default();
let request_id: String = row.try_get("request_id").unwrap_or_default();
wtr.write_record(&[
id,
created_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
action,
actor_username,
target_type,
target_id,
ip_address,
request_id,
])?;
}
wtr.into_inner().context("csv flush failed")
}

27
crates/pm-reports/src/lib.rs Executable file
View File

@ -0,0 +1,27 @@
//! pm-reports — CSV and PDF report generation.
//!
//! Uses printpdf + plotters for in-process PDF with charts.
pub mod csv;
pub mod pdf;
pub use csv::generate_csv;
pub use pdf::generate_pdf;
/// The type of report to generate.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum ReportType {
Compliance,
PatchHistory,
Vulnerability,
Audit,
}
/// Parameters controlling report generation.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ReportParams {
pub report_type: ReportType,
pub from: Option<chrono::DateTime<chrono::Utc>>,
pub to: Option<chrono::DateTime<chrono::Utc>>,
pub group_id: Option<uuid::Uuid>,
}

599
crates/pm-reports/src/pdf.rs Executable file
View File

@ -0,0 +1,599 @@
//! PDF report generation for pm-reports.
//!
//! Uses printpdf for document structure and plotters + image for embedded charts.
use crate::{ReportParams, ReportType};
use anyhow::Context;
use plotters::prelude::*;
use printpdf::{
BuiltinFont, ColorBits, ColorSpace, Image, ImageTransform, ImageXObject, IndirectFontRef, Mm,
PdfDocument, PdfLayerIndex, PdfLayerReference, PdfPageIndex, Px,
};
const PAGE_W: f32 = 297.0; // A4 landscape width (mm)
const PAGE_H: f32 = 210.0; // A4 landscape height (mm)
const MARGIN: f32 = 10.0;
const ROW_H: f32 = 6.0;
const HEADER_Y_START: f32 = 190.0;
const NEW_PAGE_THRESHOLD: f32 = 20.0;
// ---------------------------------------------------------------------------
// Public entry point
// ---------------------------------------------------------------------------
/// Generate a PDF report and return the raw bytes.
pub async fn generate_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
match params.report_type {
ReportType::Compliance => compliance_pdf(pool, params).await,
ReportType::PatchHistory => patch_history_pdf(pool, params).await,
ReportType::Vulnerability => vulnerability_pdf(pool, params).await,
ReportType::Audit => audit_pdf(pool, params).await,
}
}
// ---------------------------------------------------------------------------
// Chart helper
// ---------------------------------------------------------------------------
/// Render a bar chart to an in-memory PNG and return the raw PNG bytes.
fn render_bar_chart(
labels: &[String],
values: &[f64],
title: &str,
) -> anyhow::Result<(Vec<u8>, u32, u32)> {
const W: u32 = 800;
const H: u32 = 400;
let mut pixel_buf = vec![0u8; (W * H * 3) as usize];
// Guard: skip rendering for empty or mismatched data
if labels.is_empty() || values.is_empty() || labels.len() != values.len() {
anyhow::bail!("cannot render bar chart with empty or mismatched data");
}
// Guard: reject NaN / infinity which would panic in plotters
if values.iter().any(|v| !v.is_finite()) {
anyhow::bail!("bar chart values contain non-finite numbers");
}
{
let root = BitMapBackend::with_buffer(&mut pixel_buf, (W, H)).into_drawing_area();
root.fill(&WHITE)?;
let max_val = values.iter().cloned().fold(0.0_f64, f64::max).max(1.0);
let n = labels.len().max(1);
let mut chart = ChartBuilder::on(&root)
.caption(title, ("sans-serif", 20).into_font())
.margin(20u32)
.x_label_area_size(60u32)
.y_label_area_size(50u32)
.build_cartesian_2d(0..n, 0.0..max_val * 1.1)?;
chart
.configure_mesh()
.x_labels(n.min(20))
.x_label_formatter(&|idx| {
labels
.get(*idx)
.map(|s| {
if s.len() > 12 {
s[..12].to_string()
} else {
s.clone()
}
})
.unwrap_or_default()
})
.y_desc("Value")
.draw()?;
chart.draw_series((0..n).map(|i| {
let v = values.get(i).copied().unwrap_or(0.0);
let color = if v >= 90.0 {
RGBColor(76, 175, 80)
} else if v >= 70.0 {
RGBColor(255, 193, 7)
} else {
RGBColor(244, 67, 54)
};
Rectangle::new([(i, 0.0), (i + 1, v)], color.filled())
}))?;
root.present()?;
}
// Return raw RGB pixels + dimensions for direct PDF embedding
Ok((pixel_buf, W, H))
}
// ---------------------------------------------------------------------------
// PDF builder
// ---------------------------------------------------------------------------
struct PdfBuilder {
doc: printpdf::PdfDocumentReference,
font: IndirectFontRef,
font_bold: IndirectFontRef,
page_idx: PdfPageIndex,
layer_idx: PdfLayerIndex,
current_y: f32,
}
impl PdfBuilder {
fn new(title: &str) -> anyhow::Result<Self> {
let doc = PdfDocument::empty(title);
let (page_idx, layer_idx) = doc.add_page(Mm(PAGE_W), Mm(PAGE_H), "Layer 1");
let font = doc.add_builtin_font(BuiltinFont::Helvetica)?;
let font_bold = doc.add_builtin_font(BuiltinFont::HelveticaBold)?;
Ok(Self {
doc,
font,
font_bold,
page_idx,
layer_idx,
current_y: HEADER_Y_START,
})
}
fn layer(&self) -> PdfLayerReference {
self.doc.get_page(self.page_idx).get_layer(self.layer_idx)
}
fn write_text(&self, s: &str, font_size: f32, x: f32, y: f32, bold: bool) {
let f = if bold { &self.font_bold } else { &self.font };
self.layer().use_text(s, font_size, Mm(x), Mm(y), f);
}
fn new_page(&mut self) {
let (pi, li) = self.doc.add_page(Mm(PAGE_W), Mm(PAGE_H), "Layer 1");
self.page_idx = pi;
self.layer_idx = li;
self.current_y = HEADER_Y_START;
}
fn ensure_space(&mut self, needed: f32) {
if self.current_y - needed < NEW_PAGE_THRESHOLD {
self.new_page();
}
}
fn table_row(&mut self, cells: &[&str], col_x: &[f32], font_size: f32, bold: bool) {
self.ensure_space(ROW_H);
let y = self.current_y;
for (i, cell) in cells.iter().enumerate() {
let x = col_x.get(i).copied().unwrap_or(MARGIN);
let s = if cell.len() > 30 { &cell[..30] } else { cell };
self.write_text(s, font_size, x, y, bold);
}
self.current_y -= ROW_H;
}
#[allow(clippy::too_many_arguments)]
fn embed_image(
&self,
raw_rgb: Vec<u8>,
img_w: u32,
img_h: u32,
x_mm: f32,
y_mm: f32,
scale_x: f32,
scale_y: f32,
) -> anyhow::Result<()> {
// Validate dimensions and buffer size to prevent panics
let expected_len = (img_w as usize) * (img_h as usize) * 3;
if raw_rgb.len() != expected_len || img_w == 0 || img_h == 0 {
anyhow::bail!(
"image buffer size mismatch: expected {} bytes for {}x{} RGB, got {}",
expected_len,
img_w,
img_h,
raw_rgb.len()
);
}
if !scale_x.is_finite() || !scale_y.is_finite() || scale_x <= 0.0 || scale_y <= 0.0 {
anyhow::bail!(
"invalid image scale factors: scale_x={}, scale_y={}",
scale_x,
scale_y
);
}
let xobj = ImageXObject {
width: Px(img_w as usize),
height: Px(img_h as usize),
color_space: ColorSpace::Rgb,
bits_per_component: ColorBits::Bit8,
interpolate: true,
image_data: raw_rgb,
image_filter: None,
smask: None,
clipping_bbox: None,
};
let pdf_img = Image::from(xobj);
pdf_img.add_to_layer(
self.layer(),
ImageTransform {
translate_x: Some(Mm(x_mm)),
translate_y: Some(Mm(y_mm)),
scale_x: Some(scale_x),
scale_y: Some(scale_y),
dpi: Some(150.0),
..Default::default()
},
);
Ok(())
}
fn save(self) -> anyhow::Result<Vec<u8>> {
Ok(self.doc.save_to_bytes()?)
}
}
// ---------------------------------------------------------------------------
// Title page helper
// ---------------------------------------------------------------------------
fn write_title_page(pdf: &mut PdfBuilder, title: &str, params: &ReportParams) {
pdf.write_text(title, 24.0, MARGIN, 160.0, true);
pdf.write_text(
&format!(
"Generated: {}",
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC")
),
11.0,
MARGIN,
148.0,
false,
);
if let Some(from) = params.from {
pdf.write_text(
&format!("From: {}", from.format("%Y-%m-%d")),
10.0,
MARGIN,
140.0,
false,
);
}
if let Some(to) = params.to {
pdf.write_text(
&format!("To: {}", to.format("%Y-%m-%d")),
10.0,
MARGIN,
134.0,
false,
);
}
if let Some(gid) = params.group_id {
pdf.write_text(&format!("Group: {}", gid), 10.0, MARGIN, 128.0, false);
}
pdf.new_page();
}
// ---------------------------------------------------------------------------
// Compliance PDF
// ---------------------------------------------------------------------------
async fn compliance_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
use sqlx::Row;
let rows = if let Some(gid) = params.group_id {
sqlx::query(
"
SELECT h.display_name, h.fqdn,
COALESCE(jsonb_array_length(pd.installed_packages),0) AS total_packages,
COALESCE(pd.patch_count,0) AS pending_patches,
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages),0)=0 THEN 100.0
ELSE ROUND(CAST((1.0-pd.patch_count::float/NULLIF(jsonb_array_length(pd.installed_packages),0))*100 AS numeric),1)
END)::float8 AS compliance_pct,
h.health_status::text AS health_status
FROM hosts h LEFT JOIN host_patch_data pd ON pd.host_id=h.id
WHERE h.id IN (SELECT host_id FROM host_groups WHERE group_id=$1)
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
ORDER BY compliance_pct ASC",
)
.bind(gid)
.fetch_all(pool)
.await
.context("compliance PDF query (group) failed")?
} else {
sqlx::query(
"
SELECT h.display_name, h.fqdn,
COALESCE(jsonb_array_length(pd.installed_packages),0) AS total_packages,
COALESCE(pd.patch_count,0) AS pending_patches,
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages),0)=0 THEN 100.0
ELSE ROUND(CAST((1.0-pd.patch_count::float/NULLIF(jsonb_array_length(pd.installed_packages),0))*100 AS numeric),1)
END)::float8 AS compliance_pct,
h.health_status::text AS health_status
FROM hosts h LEFT JOIN host_patch_data pd ON pd.host_id=h.id
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
ORDER BY compliance_pct ASC",
)
.fetch_all(pool)
.await
.context("compliance PDF query failed")?
};
let labels: Vec<String> = rows
.iter()
.map(|r| r.try_get::<String, _>("display_name").unwrap_or_default())
.collect();
let values: Vec<f64> = rows
.iter()
.map(|r| r.try_get::<f64, _>("compliance_pct").unwrap_or(0.0))
.collect();
let mut pdf = PdfBuilder::new("Compliance Report")?;
write_title_page(&mut pdf, "Compliance Report", params);
let col_x: &[f32] = &[MARGIN, 65.0, 130.0, 165.0, 200.0, 235.0];
pdf.table_row(
&[
"Host",
"FQDN",
"Total Pkgs",
"Pending",
"Compliance %",
"Status",
],
col_x,
9.0,
true,
);
for row in &rows {
let name: String = row.try_get("display_name").unwrap_or_default();
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
let total: i64 = row.try_get("total_packages").unwrap_or(0);
let pend: i64 = row.try_get("pending_patches").unwrap_or(0);
let pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0);
let status: String = row.try_get("health_status").unwrap_or_default();
pdf.table_row(
&[
&name,
&fqdn,
&total.to_string(),
&pend.to_string(),
&format!("{:.1}%", pct),
&status,
],
col_x,
8.0,
false,
);
}
if !labels.is_empty() {
match render_bar_chart(&labels, &values, "Compliance % by Host") {
Ok((raw, w, h)) => {
pdf.new_page();
pdf.write_text("Compliance Chart", 16.0, MARGIN, 200.0, true);
if let Err(e) = pdf.embed_image(raw, w, h, MARGIN, 10.0, 0.28, 0.28) {
tracing::warn!(error = %e, "chart embed failed");
}
},
Err(e) => tracing::warn!(error = %e, "chart render failed"),
}
}
pdf.save()
}
// ---------------------------------------------------------------------------
// Patch history PDF
// ---------------------------------------------------------------------------
async fn patch_history_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
use sqlx::Row;
let rows = sqlx::query(
"
SELECT pj.kind::text AS job_kind, pj.status::text AS job_status,
h.display_name, h.fqdn, pjh.started_at, pjh.completed_at,
EXTRACT(EPOCH FROM (pjh.completed_at-pjh.started_at))::bigint AS duration_seconds,
COALESCE(u.username,'system') AS operator
FROM patch_job_hosts pjh
JOIN patch_jobs pj ON pj.id=pjh.job_id
JOIN hosts h ON h.id=pjh.host_id
LEFT JOIN users u ON u.id=pj.created_by_user_id
WHERE ($1::timestamptz IS NULL OR pjh.started_at>=$1)
AND ($2::timestamptz IS NULL OR pjh.started_at<=$2)
ORDER BY pjh.started_at DESC",
)
.bind(params.from)
.bind(params.to)
.fetch_all(pool)
.await
.context("patch history PDF query failed")?;
let mut dc: std::collections::BTreeMap<String, f64> = std::collections::BTreeMap::new();
for row in &rows {
if let Ok(Some(s)) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("started_at") {
*dc.entry(s.format("%Y-%m-%d").to_string()).or_insert(0.0) += 1.0;
}
}
let cl: Vec<String> = dc.keys().cloned().collect();
let cv: Vec<f64> = dc.values().cloned().collect();
let mut pdf = PdfBuilder::new("Patch History Report")?;
write_title_page(&mut pdf, "Patch History Report", params);
let col_x: &[f32] = &[MARGIN, 45.0, 80.0, 115.0, 155.0, 200.0, 245.0, 270.0];
pdf.table_row(
&[
"Kind",
"Status",
"Host",
"FQDN",
"Started",
"Completed",
"Dur(s)",
"Operator",
],
col_x,
9.0,
true,
);
for row in &rows {
let kind: String = row.try_get("job_kind").unwrap_or_default();
let status: String = row.try_get("job_status").unwrap_or_default();
let name: String = row.try_get("display_name").unwrap_or_default();
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
let started: String = row
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("started_at")
.unwrap_or(None)
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
.unwrap_or_default();
let completed: String = row
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("completed_at")
.unwrap_or(None)
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
.unwrap_or_default();
let dur: i64 = row.try_get("duration_seconds").unwrap_or(0);
let op: String = row.try_get("operator").unwrap_or_default();
pdf.table_row(
&[
&kind,
&status,
&name,
&fqdn,
&started,
&completed,
&dur.to_string(),
&op,
],
col_x,
8.0,
false,
);
}
if !cl.is_empty() {
match render_bar_chart(&cl, &cv, "Jobs per Day") {
Ok((raw, w, h)) => {
pdf.new_page();
pdf.write_text("Patch Activity Chart", 16.0, MARGIN, 200.0, true);
if let Err(e) = pdf.embed_image(raw, w, h, MARGIN, 10.0, 0.28, 0.28) {
tracing::warn!(error = %e, "chart embed failed");
}
},
Err(e) => tracing::warn!(error = %e, "chart render failed"),
}
}
pdf.save()
}
// ---------------------------------------------------------------------------
// Vulnerability PDF
// ---------------------------------------------------------------------------
async fn vulnerability_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
use sqlx::Row;
// Query DB FIRST (before creating any non-Send PdfBuilder)
let query_result = sqlx::query("
SELECT h.display_name, h.fqdn,
cve_id,
patch->>'name' AS package_name,
patch->>'severity' AS severity,
patch->>'available_version' AS available_version,
pd.polled_at AS last_seen_at
FROM hosts h JOIN host_patch_data pd ON pd.host_id=h.id
CROSS JOIN LATERAL jsonb_array_elements(COALESCE(pd.available_patches,'[]'::jsonb)) AS patch
CROSS JOIN LATERAL jsonb_array_elements_text(COALESCE(patch->'cve_ids','[]'::jsonb)) AS cve_id
WHERE ($1::timestamptz IS NULL OR pd.polled_at>=$1)
AND ($2::timestamptz IS NULL OR pd.polled_at<=$2)
ORDER BY CASE patch->>'severity' WHEN 'critical' THEN 1 WHEN 'high' THEN 2 WHEN 'medium' THEN 3 ELSE 4 END,
h.display_name")
.bind(params.from).bind(params.to).fetch_all(pool).await;
// Now create PdfBuilder (non-Send Rc types) after all awaits
let mut pdf = PdfBuilder::new("Vulnerability Report")?;
write_title_page(&mut pdf, "Vulnerability Exposure Report", params);
let col_x: &[f32] = &[MARGIN, 55.0, 100.0, 130.0, 175.0, 215.0, 255.0];
pdf.table_row(
&[
"Host",
"FQDN",
"CVE ID",
"Package",
"Severity",
"Fix Version",
"Last Seen",
],
col_x,
9.0,
true,
);
match query_result {
Ok(rows) => {
for row in &rows {
let name: String = row.try_get("display_name").unwrap_or_default();
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
let cve: String = row.try_get("cve_id").unwrap_or_default();
let pkg: String = row.try_get("package_name").unwrap_or_default();
let sev: String = row.try_get("severity").unwrap_or_default();
let fix: String = row.try_get("available_version").unwrap_or_default();
let seen: String = row
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_seen_at")
.unwrap_or(None)
.map(|d| d.format("%Y-%m-%d").to_string())
.unwrap_or_default();
pdf.table_row(
&[&name, &fqdn, &cve, &pkg, &sev, &fix, &seen],
col_x,
8.0,
false,
);
}
},
Err(e) => {
tracing::warn!(error = %e, "vulnerability PDF query failed");
let y = pdf.current_y;
pdf.write_text(&format!("No data: {}", e), 10.0, MARGIN, y, false);
},
}
pdf.save()
}
async fn audit_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
use sqlx::Row;
let rows = sqlx::query(
"
SELECT id::text AS id, created_at, action::text AS action,
actor_username, target_type, target_id,
ip_address::text AS ip_address, request_id
FROM audit_log
WHERE ($1::timestamptz IS NULL OR created_at>=$1)
AND ($2::timestamptz IS NULL OR created_at<=$2)
ORDER BY created_at DESC LIMIT 10000",
)
.bind(params.from)
.bind(params.to)
.fetch_all(pool)
.await
.context("audit PDF query failed")?;
let mut pdf = PdfBuilder::new("Audit Trail Report")?;
write_title_page(&mut pdf, "Audit Trail Report", params);
let col_x: &[f32] = &[MARGIN, 50.0, 95.0, 135.0, 175.0, 215.0, 255.0];
pdf.table_row(
&[
"Timestamp",
"Action",
"Actor",
"Target Type",
"Target ID",
"IP",
"Request ID",
],
col_x,
9.0,
true,
);
for row in &rows {
let created: String = row
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("created_at")
.unwrap_or(None)
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
.unwrap_or_default();
let action: String = row.try_get("action").unwrap_or_default();
let actor: String = row.try_get("actor_username").unwrap_or_default();
let ttype: String = row.try_get("target_type").unwrap_or_default();
let tid: String = row.try_get("target_id").unwrap_or_default();
let ip: String = row.try_get("ip_address").unwrap_or_default();
let req: String = row.try_get("request_id").unwrap_or_default();
pdf.table_row(
&[&created, &action, &actor, &ttype, &tid, &ip, &req],
col_x,
8.0,
false,
);
}
pdf.save()
}

46
crates/pm-web/Cargo.toml Normal file
View File

@ -0,0 +1,46 @@
[package]
name = "pm-web"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[[bin]]
name = "pm-web"
path = "src/main.rs"
[dependencies]
pm-ca = { path = "../pm-ca" }
pm-core = { path = "../pm-core" }
pm-auth = { path = "../pm-auth" }
pm-reports = { path = "../pm-reports" }
tokio = { workspace = true }
axum = { workspace = true }
axum-server = { workspace = true }
rustls = { workspace = true }
axum-extra = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
sqlx = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
ulid = { workspace = true }
chrono = { workspace = true }
ipnet = { workspace = true }
dashmap = { version = "6" }
tower_governor = { workspace = true }
governor = { workspace = true }
reqwest = { workspace = true }
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
rand = { workspace = true }
hex = "0.4"
base64 = { workspace = true }
sha2 = { workspace = true }
jsonwebtoken = { workspace = true }
url = { workspace = true }
urlencoding = "2"

353
crates/pm-web/src/main.rs Normal file
View File

@ -0,0 +1,353 @@
//! pm-web — Linux Patch Manager web server.
mod routes;
use axum::{extract::State, http::StatusCode, middleware, response::Json, routing::get, Router};
use axum_server::tls_rustls::RustlsConfig;
use dashmap::DashMap;
use pm_auth::{
jwt,
rbac::{require_auth, AuthConfig},
};
use pm_core::{
config::AppConfig, db, logging, models::PkiBundle, request_id::request_id_middleware,
};
use routes::sso::{OidcCache, SsoSession};
use routes::ws::WsTicket;
use serde_json::{json, Value};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::sync::Mutex;
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
};
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
/// Shared application state threaded through Axum.
#[derive(Clone)]
pub struct AppState {
pub db: sqlx::PgPool,
pub config: Arc<AppConfig>,
pub signing_key_pem: String,
pub auth_config: Arc<AuthConfig>,
/// In-memory store for single-use WebSocket authentication tickets.
pub ws_tickets: Arc<DashMap<String, WsTicket>>,
/// In-memory store for SSO PKCE sessions (state → code_verifier).
pub sso_sessions: Arc<DashMap<String, SsoSession>>,
/// Cached OIDC discovery document and JWKS for SSO id_token verification.
pub oidc_cache: Arc<Mutex<OidcCache>>,
/// Internal certificate authority for mTLS client cert issuance.
pub ca: Arc<pm_ca::CertAuthority>,
/// Short-lived cache for approved enrollment PKI bundles.
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Install the default crypto provider for rustls (required since 0.23)
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install rustls crypto provider");
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
eprintln!("Config file not found or invalid, using defaults");
AppConfig::default()
});
logging::init(&config.logging);
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
"patch-manager-web starting"
);
let signing_key_pem = jwt::load_signing_key(&config.security.jwt_signing_key_path)
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "JWT signing key not found (dev mode)");
String::new()
});
let verify_key_pem =
jwt::load_verify_key(&config.security.jwt_verify_key_path).unwrap_or_else(|e| {
tracing::warn!(error = %e, "JWT verify key not found (dev mode)");
String::new()
});
let auth_config = Arc::new(AuthConfig::new(
verify_key_pem,
&config.security.ip_whitelist,
));
let pool = db::init_pool(&config.database).await?;
db::run_migrations(&pool).await?;
// Initialise the internal CA using the configured certificate paths.
// The CA certificate and key must exist at the configured locations and be
// unencrypted PEM. If absent, a new CA is generated in that directory.
let ca_base = std::path::Path::new(&config.security.ca_cert_path)
.parent()
.expect("CA certificate path must have a parent directory");
let ca = pm_ca::CertAuthority::init(ca_base, &pool)
.await
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "CA init failed (dev mode)");
panic!("CA initialization failed: {}", e);
});
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
// Background task: purge expired WS tickets every 30 seconds.
{
let tickets = ws_tickets.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
let now = chrono::Utc::now();
let before = tickets.len();
tickets.retain(|_, v| v.expires_at > now);
let removed = before.saturating_sub(tickets.len());
if removed > 0 {
tracing::debug!(removed, "Purged expired WS tickets");
}
}
});
}
// Background task: purge expired SSO sessions every 60 seconds (sessions older than 10 minutes).
{
let sessions = sso_sessions.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
let now = chrono::Utc::now();
let cutoff = now - chrono::Duration::minutes(10);
let before = sessions.len();
sessions.retain(|_, v| v.created_at > cutoff);
let removed = before.saturating_sub(sessions.len());
if removed > 0 {
tracing::debug!(removed, "Purged expired SSO sessions");
}
}
});
}
// Background task: purge approved enrollment PKI bundles every 10 minutes.
{
let approved = approved_enrollments.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(600));
loop {
interval.tick().await;
approved.clear();
}
});
}
let state = AppState {
db: pool,
config: Arc::new(config.clone()),
signing_key_pem,
auth_config,
ws_tickets,
sso_sessions,
ca: Arc::new(ca),
approved_enrollments,
oidc_cache,
};
let app = build_router(state);
let addr: SocketAddr = format!("{}:{}", config.server.host, config.server.port)
.parse()
.expect("Invalid bind address");
// Try to load TLS certificate and key; fall back to plain HTTP if missing.
let tls_cert = std::path::Path::new(&config.security.web_tls_cert_path);
let tls_key = std::path::Path::new(&config.security.web_tls_key_path);
if tls_cert.exists() && tls_key.exists() {
let tls_config = RustlsConfig::from_pem_file(
&config.security.web_tls_cert_path,
&config.security.web_tls_key_path,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load TLS certificates");
e
})?;
tracing::info!(%addr, "Listening (HTTPS)");
axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
} else {
tracing::warn!(
cert_path = %config.security.web_tls_cert_path,
key_path = %config.security.web_tls_key_path,
"TLS certificates not found — falling back to plain HTTP. \
This is insecure for production!"
);
tracing::info!(%addr, "Listening (HTTP — no TLS)");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
}
Ok(())
}
/// Construct the full Axum router.
pub fn build_router(state: AppState) -> Router {
let static_dir = state.config.server.static_dir.clone();
let auth_config = state.auth_config.clone();
let rl = &state.config.rate_limit;
// Enrollment rate limiting: strict (5 req/min per IP, burst 3)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 12_000ms = ~5/min sustained
let enrollment_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(12_000)
.burst_size(rl.enrollment_burst)
.finish()
.expect("Invalid enrollment governor config"),
);
// Auth rate limiting: moderate (20 req/min per IP, burst 10)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 3_000ms = ~20/min sustained
let auth_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(3_000)
.burst_size(rl.auth_burst)
.finish()
.expect("Invalid auth governor config"),
);
// API rate limiting: normal (120 req/min per IP, burst 30)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 500ms = ~120/min sustained
let api_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(500)
.burst_size(rl.api_burst)
.finish()
.expect("Invalid API governor config"),
);
// Enrollment routes with strict per-IP rate limiting
let enrollment_router =
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
// Public auth routes with moderate per-IP rate limiting
let auth_public_router =
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
// SSO routes with moderate per-IP rate limiting
let sso_public_router =
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
let sso_azure_router =
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
// All protected API routes — require valid JWT, with normal per-IP rate limiting
let protected_api = Router::new()
// Auth: MFA setup/verify
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
.nest("/auth", routes::auth::protected_router())
// Hosts
.nest("/hosts", routes::hosts::router())
// Host-scoped certificate endpoints (merged separately to avoid conflict)
.nest("/hosts", routes::ca::host_cert_router())
// Groups
.nest("/groups", routes::groups::router())
// Users
.nest("/users", routes::users::router())
// Discovery
.nest("/discovery", routes::discovery::router())
// Fleet status
.nest("/status", routes::status::router())
// Patch jobs
.nest("/jobs", routes::jobs::router())
// Maintenance windows (nested under hosts path param)
.nest(
"/hosts/{host_id}/maintenance-windows",
routes::maintenance_windows::router(),
)
// Maintenance windows — bulk list-all endpoint
.nest(
"/maintenance-windows",
routes::maintenance_windows::all_windows_router(),
)
// CA root certificate download
.nest("/ca", routes::ca::ca_router())
// Certificate list / renew / revoke
.nest("/certificates", routes::ca::certs_router())
// WS ticket issuance (JWT-protected — ticket returned to browser, then used for WS upgrade)
.merge(routes::ws::ticket_router())
// Reports
.nest("/reports", routes::reports::router())
.nest(
"/hosts/{host_id}/health-checks",
routes::health_checks::router(),
)
// Settings (admin-only)
.nest("/settings", routes::settings::router())
// Admin enrollment routes (JWT protected, Admin role enforced)
.nest("/admin", routes::enrollment::admin_router())
// Apply rate limiting then auth middleware
.layer(GovernorLayer::new(api_governor))
.route_layer(middleware::from_fn(move |req, next| {
let auth_config = auth_config.clone();
require_auth(auth_config, req, next)
}));
Router::new()
.route("/status/health", get(health_handler))
// Public auth routes (rate-limited, no JWT)
.nest("/api/v1/auth", auth_public_router)
// Public enrollment endpoints (rate-limited, no JWT)
.nest("/api/v1", enrollment_router)
// Public SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/sso", sso_public_router)
// Public Azure SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/azure", sso_azure_router)
// Protected API routes (JWT required, rate-limited)
.nest("/api/v1", protected_api)
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
.merge(routes::ws::ws_router())
// Serve React SPA
.fallback_service(
ServeDir::new(&static_dir)
.append_index_html_on_directories(true)
.fallback(ServeFile::new(format!("{}/index.html", static_dir))),
)
.layer(middleware::from_fn(request_id_middleware))
.layer(TraceLayer::new_for_http())
.with_state(state)
}
async fn health_handler(State(state): State<AppState>) -> Result<Json<Value>, StatusCode> {
let db_ok = sqlx::query("SELECT 1").execute(&state.db).await.is_ok();
let status = if db_ok { "healthy" } else { "degraded" };
let body = json!({ "service": "patch-manager-web", "version": env!("CARGO_PKG_VERSION"), "status": status, "database": if db_ok { "ok" } else { "error" } });
if db_ok {
Ok(Json(body))
} else {
Err(StatusCode::SERVICE_UNAVAILABLE)
}
}

434
crates/pm-web/src/routes/auth.rs Executable file
View File

@ -0,0 +1,434 @@
//! Authentication route handlers.
//!
//! Public routes (no auth required):
//! POST /api/v1/auth/login
//! POST /api/v1/auth/refresh
//! POST /api/v1/auth/logout
//!
//! Protected routes (JWT required):
//! GET /api/v1/auth/mfa/setup
//! POST /api/v1/auth/mfa/verify
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::Json,
routing::delete,
routing::{get, post},
Router,
};
use pm_auth::{
hash_password, mfa_totp,
rbac::AuthUser,
session::{self, LoginRequest, LoginResponse},
validate_password_strength, verify_password,
};
use serde::Deserialize;
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
// ============================================================
// Public router — no authentication required
// ============================================================
pub fn public_router() -> Router<AppState> {
Router::new()
.route("/login", post(login_handler))
.route("/refresh", post(refresh_handler))
.route("/logout", post(logout_handler))
.route(
"/force-change-password",
post(force_change_password_handler),
)
}
// ============================================================
// Protected router — requires valid JWT (applied by caller)
// ============================================================
pub fn protected_router() -> Router<AppState> {
Router::new()
.route("/mfa/setup", get(mfa_setup_handler))
.route("/mfa/verify", post(mfa_verify_handler))
.route("/mfa", delete(disable_mfa))
}
// ============================================================
// Helpers
// ============================================================
fn user_agent(headers: &HeaderMap) -> Option<String> {
headers
.get("user-agent")
.and_then(|v| v.to_str().ok())
.map(str::to_string)
}
fn remote_ip(headers: &HeaderMap) -> Option<String> {
headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.map(|s| s.split(',').next().unwrap_or("").trim().to_string())
}
// ============================================================
// POST /api/v1/auth/login
// ============================================================
async fn login_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<LoginRequest>,
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
let ip = remote_ip(&headers);
let ua = user_agent(&headers);
session::login(
&state.db,
&req,
&state.signing_key_pem,
state.config.security.jwt_access_ttl_secs as i64,
ua.as_deref(),
ip.as_deref(),
)
.await
.map(Json)
.map_err(|e| {
use pm_auth::session::SessionError;
let (status, code, message) = match e {
SessionError::InvalidCredentials | SessionError::InvalidMfaCode => (
StatusCode::UNAUTHORIZED,
"invalid_credentials",
"Invalid username or password",
),
SessionError::MfaRequired => (
StatusCode::UNAUTHORIZED,
"mfa_required",
"MFA code required",
),
SessionError::AccountDisabled => (
StatusCode::FORBIDDEN,
"account_disabled",
"Account is disabled",
),
SessionError::PasswordResetRequired => (
StatusCode::FORBIDDEN,
"password_reset_required",
"Password reset is required before login",
),
SessionError::AccountLocked => (
StatusCode::LOCKED,
"account_locked",
"Account is locked due to too many failed login attempts",
),
_ => {
tracing::error!(error = %e, "Login error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"An error occurred",
)
},
};
(
status,
Json(json!({ "error": { "code": code, "message": message } })),
)
})
}
// ============================================================
// POST /api/v1/auth/refresh
// ============================================================
#[derive(Debug, Deserialize)]
struct RefreshRequest {
refresh_token: String,
}
async fn refresh_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<RefreshRequest>,
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
let ip = remote_ip(&headers);
let ua = user_agent(&headers);
session::refresh_session(
&state.db,
&req.refresh_token,
&state.signing_key_pem,
state.config.security.jwt_access_ttl_secs as i64,
ua.as_deref(),
ip.as_deref(),
)
.await
.map(Json)
.map_err(|e| {
use pm_auth::session::SessionError;
let (status, code, msg) = match e {
SessionError::Refresh(_) => (
StatusCode::UNAUTHORIZED,
"invalid_refresh_token",
"Refresh token is invalid or expired",
),
SessionError::AccountDisabled => (
StatusCode::FORBIDDEN,
"account_disabled",
"Account is disabled",
),
_ => {
tracing::error!(error = %e, "Refresh error");
(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"An error occurred",
)
},
};
(
status,
Json(json!({ "error": { "code": code, "message": msg } })),
)
})
}
// ============================================================
// POST /api/v1/auth/logout
// ============================================================
#[derive(Debug, Deserialize)]
struct LogoutRequest {
refresh_token: String,
}
async fn logout_handler(
State(state): State<AppState>,
Json(req): Json<LogoutRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
session::logout(&state.db, &req.refresh_token)
.await
.map(|_| Json(json!({ "message": "Logged out successfully" })))
.map_err(|e| {
tracing::error!(error = %e, "Logout error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "An error occurred" } })),
)
})
}
// ============================================================
// GET /api/v1/auth/mfa/setup (JWT required — via middleware)
// ============================================================
// ============================================================
// POST /api/v1/auth/force-change-password (PUBLIC — no JWT)
// ============================================================
#[derive(Debug, Deserialize)]
struct ForceChangePasswordRequest {
username: String,
current_password: String,
new_password: String,
}
async fn force_change_password_handler(
State(state): State<AppState>,
Json(req): Json<ForceChangePasswordRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Validate new password strength
if let Err(msg) = validate_password_strength(&req.new_password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
// Look up user by username
let row: Option<(Uuid, Option<String>, bool)> = sqlx::query_as(
"SELECT id, password_hash, force_password_reset FROM users WHERE username = $1 AND auth_provider = 'local'",
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to fetch user");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let (user_id, hash_opt, _force_reset) = match row {
Some(r) => r,
None => {
return Err((
StatusCode::UNAUTHORIZED,
Json(
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
),
));
},
};
// Verify current password
let hash_str = hash_opt.as_deref().unwrap_or("");
let valid = verify_password(&req.current_password, hash_str).unwrap_or(false);
if !valid {
return Err((
StatusCode::UNAUTHORIZED,
Json(
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
),
));
}
// Hash and update password, clear force_password_reset
let new_hash = hash_password(&req.new_password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
sqlx::query(
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, failed_login_attempts = 0, locked_until = NULL, updated_at = NOW() WHERE id = $2",
)
.bind(&new_hash)
.bind(user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to update password");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
)
})?;
tracing::info!(user_id = %user_id, username = %req.username, "Password changed via force-change-password");
Ok(Json(json!({ "message": "Password changed successfully" })))
}
async fn mfa_setup_handler(
auth_user: AuthUser,
) -> Result<Json<mfa_totp::TotpSetup>, (StatusCode, Json<Value>)> {
mfa_totp::generate_setup(&auth_user.username)
.map(Json)
.map_err(|e| {
tracing::error!(error = %e, "TOTP setup error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})
}
// ============================================================
// POST /api/v1/auth/mfa/verify (JWT required — via middleware)
// ============================================================
#[derive(Debug, Deserialize)]
struct MfaVerifyRequest {
secret_base32: String,
code: String,
}
async fn mfa_verify_handler(
State(state): State<AppState>,
auth_user: AuthUser,
Json(req): Json<MfaVerifyRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let valid =
mfa_totp::verify_code(&auth_user.username, &req.secret_base32, &req.code).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
if !valid {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "invalid_code", "message": "Invalid TOTP code" } })),
));
}
sqlx::query("UPDATE users SET totp_secret = $1, mfa_enabled = TRUE WHERE id = $2")
.bind(&req.secret_base32)
.bind(auth_user.user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to save TOTP secret");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to enable MFA" } })),
)
})?;
tracing::info!(user_id = %auth_user.user_id, "MFA enabled for user");
Ok(Json(json!({ "message": "MFA enabled successfully" })))
}
// ============================================================
// DELETE /api/v1/auth/mfa (JWT required — disable own MFA)
// ============================================================
#[derive(Debug, Deserialize)]
struct DisableMfaRequest {
password: String,
}
async fn disable_mfa(
State(state): State<AppState>,
auth_user: AuthUser,
Json(req): Json<DisableMfaRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Verify current password to confirm identity
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
.bind(auth_user.user_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to fetch password hash");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?
.flatten();
let hash_str = hash.unwrap_or_default();
let valid = verify_password(&req.password, &hash_str).unwrap_or(false);
if !valid {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
),
));
}
sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE WHERE id = $1")
.bind(auth_user.user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to disable MFA");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
)
})?;
tracing::info!(user_id = %auth_user.user_id, "MFA disabled for user");
Ok(Json(json!({ "message": "MFA disabled successfully" })))
}

516
crates/pm-web/src/routes/ca.rs Executable file
View File

@ -0,0 +1,516 @@
//! CA / certificate management routes.
//!
//! ca_router() → mounted at /api/v1/ca
//! GET /root.crt download_root_ca (any authed role)
//!
//! certs_router() → mounted at /api/v1/certificates
//! GET / list_certificates (any authed role)
//! POST /:cert_id/renew renew_cert (admin only)
//! DELETE /:cert_id revoke_cert (admin only)
//!
//! host_cert_router() → merged under /api/v1/hosts
//! GET /:host_id/client.crt download_client_cert (admin only)
//! POST /:host_id/certificates issue_client_cert (admin only)
//! POST /:host_id/certificates/reissue reissue_host_cert (admin only)
use axum::{
body::Body,
extract::{Path, Query, State},
http::{header, Response, StatusCode},
response::Json,
routing::{delete, get, post},
Router,
};
use chrono::{DateTime, Utc};
use pm_auth::rbac::AuthUser;
use pm_core::audit::{log_event, AuditAction};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlx::Row;
use uuid::Uuid;
use crate::AppState;
// ── Router constructors ───────────────────────────────────────────────────────
/// Handles routes mounted at /api/v1/ca
pub fn ca_router() -> Router<AppState> {
Router::new().route("/root.crt", get(download_root_ca))
}
/// Handles routes mounted at /api/v1/certificates
pub fn certs_router() -> Router<AppState> {
Router::new()
.route("/", get(list_certificates))
.route("/{cert_id}/renew", post(renew_cert))
.route("/{cert_id}", delete(revoke_cert))
}
/// Handles cert-specific paths merged under /api/v1/hosts.
/// Only adds paths not already claimed by the hosts router.
pub fn host_cert_router() -> Router<AppState> {
Router::new()
.route("/{host_id}/client.crt", get(download_client_cert))
.route("/{host_id}/certificates", post(issue_client_cert))
.route("/{host_id}/certificates/reissue", post(reissue_host_cert))
}
// ── Shared types ──────────────────────────────────────────────────────────────
/// Row returned from the `certificates` table.
#[derive(Debug, Serialize, sqlx::FromRow)]
struct CertRow {
id: Uuid,
host_id: Option<Uuid>,
serial_number: String,
common_name: String,
/// Cast to TEXT in all queries to avoid custom-enum decode.
status: String,
issued_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
revoked_at: Option<DateTime<Utc>>,
}
/// Query params for `list_certificates`.
#[derive(Debug, Deserialize)]
struct CertListQuery {
host_id: Option<Uuid>,
status: Option<String>,
}
/// Request body for `issue_client_cert`.
#[derive(Debug, Deserialize)]
struct IssueCertRequest {
hostname: String,
}
// ── Helper: build PEM download response ──────────────────────────────────────
fn pem_response(pem: String, filename: &str) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
let disposition = format!("attachment; filename=\"{filename}\"");
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/x-pem-file")
.header(header::CONTENT_DISPOSITION, disposition)
.body(Body::from(pem))
.map_err(|e| {
tracing::error!(error = %e, "Failed to build PEM response");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Response build error" } })),
)
})
}
// ── Helper: admin-only guard ──────────────────────────────────────────────────
fn require_write_access(user: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
if !user.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
Ok(())
}
// ── Helper: map sqlx error to 500 ─────────────────────────────────────────────
fn db_error(e: sqlx::Error) -> (StatusCode, Json<Value>) {
tracing::error!(error = %e, "Database error");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
}
// ── Helper: build the full IssuedCert JSON response ──────────────────────────
fn issued_cert_json(issued: &pm_ca::IssuedCert) -> Value {
json!({
"cert_pem": issued.cert_pem,
"key_pem": issued.key_pem,
"serial_number": issued.serial_number,
"expires_at": issued.expires_at,
"server_cert_pem": issued.server_cert_pem,
"server_key_pem": issued.server_key_pem,
"server_serial_number": issued.server_serial_number,
"ca_root_pem": issued.ca_root_pem,
})
}
// ── GET /api/v1/ca/root.crt ───────────────────────────────────────────────────
/// Download the root CA certificate as a PEM file.
async fn download_root_ca(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
let pem = state.ca.root_cert_pem().to_owned();
log_event(
&state.db,
AuditAction::CertificateDownloaded,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some("root_ca"),
json!({ "operation": "download_root_ca" }),
None,
None,
)
.await;
pem_response(pem, "ca.crt")
}
// ── GET /api/v1/certificates ──────────────────────────────────────────────────
/// List certificates with optional `?host_id=` and `?status=` filters.
async fn list_certificates(
State(state): State<AppState>,
_auth: AuthUser,
Query(q): Query<CertListQuery>,
) -> Result<Json<Vec<CertRow>>, (StatusCode, Json<Value>)> {
// Use the non-macro query_as form — avoids needing DATABASE_URL at compile
// time. status is cast to TEXT so sqlx decodes it into String directly.
let rows: Vec<CertRow> = match (q.host_id, q.status.as_deref()) {
(Some(hid), Some(st)) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
WHERE host_id = $1 AND status::text = $2
ORDER BY issued_at DESC"#,
)
.bind(hid)
.bind(st)
.fetch_all(&state.db)
.await
},
(Some(hid), None) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
WHERE host_id = $1
ORDER BY issued_at DESC"#,
)
.bind(hid)
.fetch_all(&state.db)
.await
},
(None, Some(st)) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
WHERE status::text = $1
ORDER BY issued_at DESC"#,
)
.bind(st)
.fetch_all(&state.db)
.await
},
(None, None) => {
sqlx::query_as::<_, CertRow>(
r#"SELECT id, host_id, serial_number, common_name,
status::text AS status,
issued_at, expires_at, revoked_at
FROM certificates
ORDER BY issued_at DESC"#,
)
.fetch_all(&state.db)
.await
},
}
.map_err(db_error)?;
Ok(Json(rows))
}
// ── GET /api/v1/hosts/:host_id/client.crt ────────────────────────────────────
/// Download the most recent active client certificate PEM for a host.
async fn download_client_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
let cert_pem: Option<String> = sqlx::query_scalar(
r#"SELECT cert_pem
FROM certificates
WHERE host_id = $1
AND status = 'active'::cert_status
AND common_name NOT LIKE '%-server'
ORDER BY issued_at DESC
LIMIT 1"#,
)
.bind(host_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to fetch client cert");
db_error(e)
})?;
match cert_pem {
Some(pem) => {
log_event(
&state.db,
AuditAction::CertificateDownloaded,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&host_id.to_string()),
json!({ "operation": "download_client_cert" }),
None,
None,
)
.await;
pem_response(pem, "client.crt")
},
None => Err((
StatusCode::NOT_FOUND,
Json(json!({
"error": {
"code": "not_found",
"message": "No active certificate found for this host"
}
})),
)),
}
}
// ── POST /api/v1/hosts/:host_id/certificates ─────────────────────────────────
/// Issue a new mTLS client certificate (and server certificate) for a host.
/// **The private keys are returned only once — the caller must save them.**
async fn issue_client_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
Json(req): Json<IssueCertRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
// Look up the host's IP address from the database.
let ip_address: String = sqlx::query_scalar("SELECT host(ip_address) FROM hosts WHERE id = $1")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to fetch host IP address");
if e.to_string().contains("no rows") {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
} else {
db_error(e)
}
})?;
let issued = state
.ca
.issue_client_cert(host_id, &req.hostname, &ip_address, &state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, hostname = %req.hostname,
"Failed to issue client cert");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::CertificateIssued,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&host_id.to_string()),
json!({ "hostname": req.hostname, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
None,
None,
)
.await;
Ok(Json(issued_cert_json(&issued)))
}
// ── POST /api/v1/certificates/:cert_id/renew ─────────────────────────────────
/// Revoke the specified certificate and issue a replacement with the same CN.
async fn renew_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(cert_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
let issued = state.ca.renew_cert(cert_id, &state.db).await.map_err(|e| {
let msg = e.to_string();
tracing::error!(error = %e, %cert_id, "Failed to renew cert");
if msg.contains("not found") {
(
StatusCode::NOT_FOUND,
Json(
json!({ "error": { "code": "not_found", "message": "Certificate not found" } }),
),
)
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
)
}
})?;
log_event(
&state.db,
AuditAction::CertificateRenewed,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&cert_id.to_string()),
json!({ "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
None,
None,
)
.await;
Ok(Json(issued_cert_json(&issued)))
}
// ── POST /api/v1/hosts/:host_id/certificates/reissue ────────────────────────
/// Revoke ALL active certificates for a host and issue new ones.
/// The private keys are returned only once — the caller must save them.
async fn reissue_host_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
// Look up the host's FQDN and IP address for the new certificate CN and SANs.
let row = sqlx::query("SELECT fqdn, host(ip_address) AS ip_address FROM hosts WHERE id = $1")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to fetch host FQDN/IP");
if e.to_string().contains("no rows") {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
} else {
db_error(e)
}
})?;
let fqdn: String = row.try_get("fqdn").map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to read fqdn");
db_error(e)
})?;
let ip_address: String = row.try_get("ip_address").map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to read ip_address");
db_error(e)
})?;
// Revoke all active certificates for this host.
let revoked = sqlx::query(
"UPDATE certificates SET status = 'revoked'::cert_status, revoked_at = NOW() \
WHERE host_id = $1 AND status = 'active'::cert_status",
)
.bind(host_id)
.execute(&state.db)
.await
.map_err(db_error)?;
tracing::info!(%host_id, rows_revoked = revoked.rows_affected(), "Revoked all active certs for host");
// Issue a new certificate bundle using the host's FQDN and IP.
let issued = state
.ca
.issue_client_cert(host_id, &fqdn, &ip_address, &state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "Failed to issue new cert during reissue");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::CertificateReissued,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&host_id.to_string()),
json!({ "hostname": &fqdn, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number, "rows_revoked": revoked.rows_affected() }),
None,
None,
)
.await;
Ok(Json(issued_cert_json(&issued)))
}
// ── DELETE /api/v1/certificates/:cert_id ─────────────────────────────────────
/// Revoke a certificate by ID. Sets status to 'revoked' in the database.
async fn revoke_cert(
State(state): State<AppState>,
auth: AuthUser,
Path(cert_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
require_write_access(&auth)?;
state
.ca
.revoke_cert(cert_id, &state.db)
.await
.map_err(|e| {
let msg = e.to_string();
tracing::error!(error = %e, %cert_id, "Failed to revoke cert");
if msg.contains("not found") {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Certificate not found" } })),
)
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
)
}
})?;
tracing::info!(%cert_id, "Certificate revoked via API");
log_event(
&state.db,
AuditAction::CertificateRevoked,
Some(auth.user_id),
Some(&auth.username),
Some("certificate"),
Some(&cert_id.to_string()),
json!({ "operation": "revoke" }),
None,
None,
)
.await;
Ok(Json(json!({ "revoked": true })))
}

View File

@ -0,0 +1,304 @@
//! CIDR auto-discovery routes.
//!
//! POST /api/v1/discovery/cidr — start a CIDR scan
//! GET /api/v1/discovery/:scan_id — get scan results
//! POST /api/v1/discovery/:id/register — register a discovered host
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{DiscoveryCidrRequest, DiscoveryResult, RegisterDiscoveredRequest},
};
use serde_json::{json, Value};
use std::{
net::{IpAddr, TcpStream},
time::Duration,
};
use tokio::{sync::Semaphore, task};
use uuid::Uuid;
use crate::AppState;
/// Maximum concurrent TCP probes during CIDR scan.
const MAX_CONCURRENT_PROBES: usize = 128;
/// TCP connect timeout per probe.
const PROBE_TIMEOUT_SECS: u64 = 2;
pub fn router() -> Router<AppState> {
Router::new()
.route("/cidr", post(start_cidr_scan))
.route("/{scan_id}", get(get_scan_results))
.route("/{id}/register", post(register_discovered_host))
}
// ── POST /api/v1/discovery/cidr ───────────────────────────────────────────────
async fn start_cidr_scan(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<DiscoveryCidrRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let cidr: ipnet::IpNet = req.cidr.parse().map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "bad_request", "message": "Invalid CIDR range" } })),
)
})?;
let agent_port = req.agent_port.unwrap_or(12443) as u16;
let scan_id = Uuid::new_v4();
// Clear previous results for this type of scan and start async scan
let pool = state.db.clone();
let scan_id_clone = scan_id;
let cidr_str = req.cidr.clone();
// Spawn non-blocking background scan
task::spawn(async move {
run_cidr_scan(pool, scan_id_clone, cidr, agent_port).await;
});
log_event(
&state.db,
AuditAction::DiscoveryScanStarted,
Some(auth.user_id),
Some(&auth.username),
Some("discovery"),
Some(&scan_id.to_string()),
json!({ "cidr": cidr_str }),
None,
None,
)
.await;
tracing::info!(scan_id = %scan_id, cidr = %req.cidr, "CIDR scan started");
Ok(Json(
json!({ "scan_id": scan_id, "message": "Discovery scan started", "cidr": req.cidr }),
))
}
/// Background CIDR scanner.
async fn run_cidr_scan(pool: sqlx::PgPool, scan_id: Uuid, cidr: ipnet::IpNet, port: u16) {
let semaphore = std::sync::Arc::new(Semaphore::new(MAX_CONCURRENT_PROBES));
let hosts: Vec<IpAddr> = cidr.hosts().collect();
let total = hosts.len();
tracing::info!(scan_id = %scan_id, total = total, "CIDR scan probing {} hosts", total);
let mut handles = Vec::new();
for ip in hosts {
let sem = semaphore.clone();
let pool_clone = pool.clone();
let h = task::spawn(async move {
let _permit = sem.acquire().await.ok()?;
probe_and_store(pool_clone, scan_id, ip, port).await
});
handles.push(h);
}
for h in handles {
let _ = h.await;
}
tracing::info!(scan_id = %scan_id, "CIDR scan complete");
}
/// Probe a single IP:port and store the result if the port is open.
async fn probe_and_store(pool: sqlx::PgPool, scan_id: Uuid, ip: IpAddr, port: u16) -> Option<()> {
let addr = format!("{ip}:{port}");
// TCP connect probe (blocking, run in thread pool)
// TCP connect probe (blocking, run in thread pool)
let addr_clone = addr.clone();
let open = task::spawn_blocking(move || {
TcpStream::connect_timeout(
&match addr_clone.parse() {
Ok(a) => a,
Err(_) => return false,
},
Duration::from_secs(PROBE_TIMEOUT_SECS),
)
.is_ok()
})
.await
.unwrap_or(false);
if !open {
return None;
}
// Reverse DNS lookup (best-effort)
let ip_clone = ip;
let fqdn = task::spawn_blocking(move || {
use std::net::ToSocketAddrs;
let addr = format!("{ip_clone}:{port}");
addr.to_socket_addrs()
.ok()
.and_then(|mut a| a.next())
.and_then(|_| dns_lookup_for_ip(ip_clone))
})
.await
.ok()
.flatten();
let _ = sqlx::query(
r#"INSERT INTO discovery_results (scan_id, ip_address, fqdn, agent_port)
VALUES ($1, $2::inet, $3, $4)
ON CONFLICT DO NOTHING"#,
)
.bind(scan_id)
.bind(ip.to_string())
.bind(fqdn)
.bind(port as i32)
.execute(&pool)
.await;
tracing::debug!(ip = %ip, port = port, "Discovered agent");
Some(())
}
/// Simple reverse DNS lookup.
fn dns_lookup_for_ip(ip: IpAddr) -> Option<String> {
use std::net::{SocketAddr, ToSocketAddrs};
let _addr = SocketAddr::new(ip, 0);
// Standard library doesn't have reverse lookup; use getaddrinfo via format
let host = format!("{ip}");
// Best-effort: try to resolve numeric address to hostname
(host + ":0")
.to_socket_addrs()
.ok()?
.next()
.map(|a| a.ip().to_string())
.filter(|s| s != &ip.to_string())
}
// ── GET /api/v1/discovery/:scan_id ────────────────────────────────────────────
async fn get_scan_results(
State(state): State<AppState>,
_auth: AuthUser,
Path(scan_id): Path<Uuid>,
) -> Result<Json<Vec<DiscoveryResult>>, (StatusCode, Json<Value>)> {
sqlx::query_as::<_, DiscoveryResult>(
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
agent_version, os_name, agent_port, discovered_at, registered
FROM discovery_results
WHERE scan_id = $1
ORDER BY ip_address"#,
)
.bind(scan_id)
.fetch_all(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})
}
// ── POST /api/v1/discovery/:id/register ──────────────────────────────────────
async fn register_discovered_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<RegisterDiscoveredRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Fetch discovery result
let result: Option<DiscoveryResult> = sqlx::query_as(
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
agent_version, os_name, agent_port, discovered_at, registered
FROM discovery_results WHERE id = $1"#,
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
let result = result.ok_or_else(|| (
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Discovery result not found" } }))
))?;
let fqdn = result.fqdn.as_deref().unwrap_or(&result.ip_address);
let display_name = req.display_name.as_deref().unwrap_or(fqdn);
let host_id: Uuid = sqlx::query_scalar(
r#"INSERT INTO hosts (fqdn, ip_address, display_name, agent_port)
VALUES ($1, $2::inet, $3, $4)
ON CONFLICT DO NOTHING
RETURNING id"#,
)
.bind(fqdn)
.bind(&result.ip_address)
.bind(display_name)
.bind(result.agent_port)
.fetch_one(&state.db)
.await
.map_err(|e| {
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": e.to_string() } })),
)
})?;
// Assign to groups
if let Some(group_ids) = &req.group_ids {
for gid in group_ids {
let _ = sqlx::query("INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING")
.bind(host_id).bind(gid).execute(&state.db).await;
}
}
// Mark as registered
let _ = sqlx::query("UPDATE discovery_results SET registered = TRUE WHERE id = $1")
.bind(id)
.execute(&state.db)
.await;
log_event(
&state.db,
AuditAction::HostRegistered,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&host_id.to_string()),
json!({ "from_discovery": true, "ip": result.ip_address }),
None,
None,
)
.await;
Ok(Json(
json!({ "host_id": host_id, "message": "Host registered from discovery" }),
))
}

View File

@ -0,0 +1,319 @@
use crate::AppState;
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
Json, Router,
};
use chrono::Utc;
use pm_auth::AuthUser;
use pm_core::{
db,
models::{
CreateEnrollmentRequest, EnrollmentRequest, EnrollmentStatusResponse, Host, PkiBundle,
},
};
use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize;
#[derive(Debug, Clone, Serialize)]
pub struct HostConflict {
pub existing_host: Host,
pub message: String,
}
/// Define public enrollment routes.
pub fn router() -> Router<AppState> {
Router::new()
.route("/enroll", post(enroll_host))
.route("/enroll/status/{token}", get(enroll_status))
}
/// POST /api/v1/enroll
/// Initiates host self-enrollment.
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
async fn enroll_host(
State(state): State<AppState>,
Json(payload): Json<CreateEnrollmentRequest>,
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
// Generate secure random polling token
let polling_token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
// For database storage, we'll hash the token (spec says hashed)
// Using a simple SHA256 or similar for the hash storage
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(polling_token.as_bytes());
let token_hash = hex::encode(hasher.finalize());
// 3. Store in DB
db::create_enrollment_request(&state.db, payload, token_hash)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to create enrollment request");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// 4. Return the raw token to the client
Ok((
StatusCode::ACCEPTED,
Json(serde_json::json!({ "polling_token": polling_token })),
)
.into_response())
}
/// GET /api/v1/enroll/status/{token}
/// Returns status of enrollment (pending/approved/denied/not_found).
async fn enroll_status(
State(state): State<AppState>,
Path(token): Path<String>,
) -> Result<Json<EnrollmentStatusResponse>, (StatusCode, Json<serde_json::Value>)> {
// Hash the provided token to match DB
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(token.as_bytes());
let token_hash = hex::encode(hasher.finalize());
// 1. Check enrollment_requests table
let requests = db::list_enrollment_requests(&state.db).await.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
if let Some(req) = requests.into_iter().find(|r| r.polling_token == token_hash) {
if req.expires_at < Utc::now() {
return Ok(Json(EnrollmentStatusResponse::NotFound));
}
return Ok(Json(EnrollmentStatusResponse::Pending));
}
// 2. If not in pending, check if it was recently approved.
if let Some(pki) = state.approved_enrollments.get(&token_hash) {
return Ok(Json(EnrollmentStatusResponse::Approved {
ca_crt: pki.ca_crt.clone(),
server_crt: pki.server_crt.clone(),
server_key: pki.server_key.clone(),
}));
}
Ok(Json(EnrollmentStatusResponse::NotFound))
}
/// Define admin enrollment routes.
pub fn admin_router() -> Router<AppState> {
Router::new()
.route("/enrollments", get(list_admin_enrollments))
.route("/enrollments/{id}/approve", post(approve_enrollment))
.route("/enrollments/{id}/deny", delete(deny_enrollment))
}
/// GET /api/v1/admin/enrollments
/// Lists all pending enrollment requests.
async fn list_admin_enrollments(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Vec<EnrollmentRequest>>, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
db::list_enrollment_requests(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e, "Failed to list enrollment requests");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})
}
/// POST /api/v1/admin/enrollments/{id}/approve
/// Approves a pending enrollment request, generates PKI, and moves to hosts table.
async fn approve_enrollment(
State(state): State<AppState>,
Path(id): Path<uuid::Uuid>,
auth: AuthUser,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
// Fetch the enrollment request
let mut requests = db::list_enrollment_requests(&state.db).await.map_err(|e| {
tracing::error!(error = %e, "Failed to list enrollment requests for approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
let enrollment_request = match requests.iter().position(|r| r.id == id) {
Some(idx) => requests.remove(idx),
None => return Ok(StatusCode::NOT_FOUND),
};
// Check for FQDN/IP collision in hosts table
if let Some(existing_host) = sqlx::query_as::<_, Host>(
"SELECT id, fqdn, ip_address::text, display_name, os_family, os_name, arch, agent_version, health_status, last_health_at, last_patch_at, agent_port, notes, registered_at, updated_at FROM hosts WHERE fqdn = $1 OR ip_address = $2::inet"
)
.bind(&enrollment_request.fqdn)
.bind(enrollment_request.ip_address.to_string())
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to check for host collision");
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": "Database error" })))
})? {
return Err((
StatusCode::CONFLICT,
Json(serde_json::json!({ "error": "Host collision detected", "conflict": HostConflict { existing_host, message: "FQDN or IP already exists".to_string() } }))
));
}
// Move to hosts table FIRST (certificates table has FK reference to hosts)
let os_family = enrollment_request
.os_details
.get("os")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let os_name = enrollment_request
.os_details
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| {
// Build os_name from os + os_version if "name" is absent
let os = enrollment_request
.os_details
.get("os")
.and_then(|v| v.as_str())?;
let ver = enrollment_request
.os_details
.get("os_version")
.and_then(|v| v.as_str())
.unwrap_or("");
Some(format!("{} {}", os, ver).trim().to_string())
});
let arch = enrollment_request
.os_details
.get("architecture")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let display_name = enrollment_request
.hostname
.clone()
.unwrap_or_else(|| enrollment_request.fqdn.clone());
sqlx::query(
r#"
INSERT INTO hosts (id, fqdn, ip_address, os_family, os_name, arch, display_name, registered_at, updated_at)
VALUES ($1, $2, $3::inet, $4, $5, $6, $7, NOW(), NOW())
"#,
)
.bind(enrollment_request.id)
.bind(&enrollment_request.fqdn)
.bind(enrollment_request.ip_address.to_string())
.bind(&os_family)
.bind(&os_name)
.bind(&arch)
.bind(&display_name)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to insert host after approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// Generate PKI bundle using CA (after host row exists)
let issued = state
.ca
.issue_client_cert(
enrollment_request.id,
&enrollment_request.fqdn,
&enrollment_request.ip_address.to_string(),
&state.db,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to issue client certificate");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Certificate generation failed" })),
)
})?;
// Delete from enrollment_requests table
db::delete_enrollment_request(&state.db, id)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to delete enrollment request after approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// Store PKI bundle in cache for client retrieval
let pki = PkiBundle {
ca_crt: issued.ca_root_pem,
server_crt: issued.server_cert_pem,
server_key: issued.server_key_pem,
};
state
.approved_enrollments
.insert(enrollment_request.polling_token.clone(), pki);
Ok(StatusCode::OK)
}
/// DELETE /api/v1/admin/enrollments/{id}/deny
/// Denies and purges a pending enrollment request.
async fn deny_enrollment(
State(state): State<AppState>,
Path(id): Path<uuid::Uuid>,
auth: AuthUser,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
db::delete_enrollment_request(&state.db, id)
.await
.map(|_| StatusCode::NO_CONTENT)
.map_err(|e| {
tracing::error!(error = %e, "Failed to deny enrollment request");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})
}

View File

@ -0,0 +1,312 @@
//! Group management routes.
//!
//! GET /api/v1/groups — list all groups
//! POST /api/v1/groups — create group (admin)
//! GET /api/v1/groups/:id — get group detail + members
//! PUT /api/v1/groups/:id — update group (admin)
//! DELETE /api/v1/groups/:id — delete group (admin)
//! POST /api/v1/groups/:id/users/:user_id — add user to group (admin)
//! DELETE /api/v1/groups/:id/users/:user_id — remove user from group (admin)
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateGroupRequest, Group, UpdateGroupRequest},
};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_groups).post(create_group))
.route(
"/{id}",
get(get_group).put(update_group).delete(delete_group),
)
.route(
"/{id}/users/{user_id}",
post(add_user_to_group).delete(remove_user_from_group),
)
}
async fn list_groups(
State(state): State<AppState>,
_auth: AuthUser,
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
sqlx::query_as::<_, Group>(
"SELECT id, name, description, created_at, updated_at FROM groups ORDER BY name",
)
.fetch_all(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e, "Failed to list groups");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})
}
async fn create_group(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateGroupRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let id: Uuid =
sqlx::query_scalar("INSERT INTO groups (name, description) VALUES ($1, $2) RETURNING id")
.bind(&req.name)
.bind(req.description.as_deref().unwrap_or(""))
.fetch_one(&state.db)
.await
.map_err(|e| {
let msg = if e.to_string().contains("unique") {
"Group name already exists".to_string()
} else {
"Database error".to_string()
};
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupCreated,
Some(auth.user_id),
Some(&auth.username),
Some("group"),
Some(&id.to_string()),
json!({ "name": req.name }),
None,
None,
)
.await;
Ok(Json(json!({ "id": id, "message": "Group created" })))
}
async fn get_group(
State(state): State<AppState>,
_auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let group: Option<Group> = sqlx::query_as(
"SELECT id, name, description, created_at, updated_at FROM groups WHERE id = $1",
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let group = group.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
)
})?;
// Fetch member counts
let host_count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM host_groups WHERE group_id = $1")
.bind(id)
.fetch_one(&state.db)
.await
.unwrap_or(0);
let user_count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM user_groups WHERE group_id = $1")
.bind(id)
.fetch_one(&state.db)
.await
.unwrap_or(0);
Ok(Json(
json!({ "group": group, "host_count": host_count, "user_count": user_count }),
))
}
async fn update_group(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<UpdateGroupRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let rows = sqlx::query(
"UPDATE groups SET name = COALESCE($1, name), description = COALESCE($2, description), updated_at = NOW() WHERE id = $3"
)
.bind(req.name.as_deref())
.bind(req.description.as_deref())
.bind(id)
.execute(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
));
}
Ok(Json(json!({ "message": "Group updated" })))
}
async fn delete_group(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
let rows = sqlx::query("DELETE FROM groups WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
));
}
log_event(
&state.db,
AuditAction::GroupDeleted,
Some(auth.user_id),
Some(&auth.username),
Some("group"),
Some(&id.to_string()),
json!({}),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Group deleted" })))
}
async fn add_user_to_group(
State(state): State<AppState>,
auth: AuthUser,
Path((id, user_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query(
"INSERT INTO user_groups (user_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
)
.bind(user_id)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("user_group"),
Some(&id.to_string()),
json!({ "user_id": user_id, "action": "added" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User added to group" })))
}
async fn remove_user_from_group(
State(state): State<AppState>,
auth: AuthUser,
Path((id, user_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query("DELETE FROM user_groups WHERE user_id = $1 AND group_id = $2")
.bind(user_id)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("user_group"),
Some(&id.to_string()),
json!({ "user_id": user_id, "action": "removed" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User removed from group" })))
}

File diff suppressed because it is too large Load Diff

678
crates/pm-web/src/routes/hosts.rs Executable file
View File

@ -0,0 +1,678 @@
//! Host management routes.
//!
//! GET /api/v1/hosts — list hosts (RBAC scoped)
//! POST /api/v1/hosts — register new host (admin only)
//! GET /api/v1/hosts/{id} — get host detail
//! DELETE /api/v1/hosts/{id} — remove host (admin only)
//! PUT /api/v1/hosts/{id} — update host (write access)
//! GET /api/v1/hosts/{id}/groups — list groups for host
//! POST /api/v1/hosts/{id}/groups — assign host to group
//! DELETE /api/v1/hosts/{id}/groups/{group_id} — remove host from group
//! POST /api/v1/hosts/{id}/refresh — queue on-demand refresh (write access)
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
routing::{delete, get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateHostRequest, Group, HostSummary, UpdateHostRequest},
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_hosts).post(register_host))
.route("/{id}", get(get_host).put(update_host).delete(remove_host))
.route(
"/{id}/groups",
get(list_host_groups).post(add_host_to_group),
)
.route("/{id}/groups/{group_id}", delete(remove_host_from_group))
.route("/{id}/refresh", post(refresh_host))
}
// ── Query params ─────────────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct HostListQuery {
pub group_id: Option<Uuid>,
pub health_status: Option<String>,
pub os_family: Option<String>,
pub search: Option<String>,
pub limit: Option<i64>,
pub offset: Option<i64>,
}
// ── Response types ────────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct HostListResponse {
hosts: Vec<HostSummary>,
total: i64,
limit: i64,
offset: i64,
}
// ── Helper: check if operator can access a host ───────────────────────────────
async fn operator_can_access_host(
pool: &sqlx::PgPool,
user_id: Uuid,
host_id: Uuid,
) -> Result<bool, sqlx::Error> {
// Admins can access all; operators can access hosts in their groups
// OR ungrouped hosts (no group memberships)
let in_group: bool = sqlx::query_scalar(
r#"
SELECT EXISTS (
SELECT 1 FROM host_groups hg
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE hg.host_id = $1 AND ug.user_id = $2
)
"#,
)
.bind(host_id)
.bind(user_id)
.fetch_one(pool)
.await?;
if in_group {
return Ok(true);
}
// Ungrouped hosts are accessible to any operator
let ungrouped: bool =
sqlx::query_scalar("SELECT NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = $1)")
.bind(host_id)
.fetch_one(pool)
.await?;
Ok(ungrouped)
}
// ── GET /api/v1/hosts ─────────────────────────────────────────────────────────
async fn list_hosts(
State(state): State<AppState>,
auth: AuthUser,
Query(q): Query<HostListQuery>,
) -> Result<Json<HostListResponse>, (StatusCode, Json<Value>)> {
let limit = q.limit.unwrap_or(50).min(200);
let offset = q.offset.unwrap_or(0);
// For operators: only show hosts in their groups (or ungrouped)
let hosts: Vec<HostSummary> = if auth.role.is_admin() {
sqlx::query_as(
r#"
SELECT h.id, h.fqdn, host(h.ip_address)::text AS ip_address, h.display_name,
h.os_family, h.os_name, h.health_status, h.agent_version,
COALESCE(hpd.patch_count, 0) AS patches_missing,
CASE
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
THEN NULL
WHEN EXISTS (
SELECT 1 FROM host_health_checks hc
LEFT JOIN LATERAL (
SELECT healthy FROM host_health_check_results r
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
) lr ON TRUE
WHERE hc.host_id = h.id AND hc.enabled = TRUE
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
)
THEN 'some_unhealthy'
ELSE 'all_healthy'
END AS health_check_status,
h.registered_at
FROM hosts h
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
ORDER BY h.fqdn
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await
} else {
sqlx::query_as(
r#"
SELECT DISTINCT h.id, h.fqdn, host(h.ip_address)::text AS ip_address,
h.display_name, h.os_family, h.os_name,
h.health_status, h.agent_version,
COALESCE(hpd.patch_count, 0) AS patches_missing,
CASE
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
THEN NULL
WHEN EXISTS (
SELECT 1 FROM host_health_checks hc
LEFT JOIN LATERAL (
SELECT healthy FROM host_health_check_results r
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
) lr ON TRUE
WHERE hc.host_id = h.id AND hc.enabled = TRUE
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
)
THEN 'some_unhealthy'
ELSE 'all_healthy'
END AS health_check_status,
h.registered_at
FROM hosts h
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
WHERE
-- Hosts in operator's groups
EXISTS (
SELECT 1 FROM host_groups hg
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE hg.host_id = h.id AND ug.user_id = $3
)
-- OR ungrouped hosts
OR NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = h.id)
ORDER BY h.fqdn
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.bind(auth.user_id)
.fetch_all(&state.db)
.await
}
.map_err(|e| {
tracing::error!(error = %e, "Failed to list hosts");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM hosts")
.fetch_one(&state.db)
.await
.unwrap_or(0);
Ok(Json(HostListResponse {
hosts,
total,
limit,
offset,
}))
}
// ── POST /api/v1/hosts ────────────────────────────────────────────────────────
async fn register_host(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateHostRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Admin only
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Resolve FQDN to IP address
let ip_address = resolve_fqdn(&req.fqdn).await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "fqdn_resolution_failed", "message": e } })),
)
})?;
let display_name = req.display_name.clone().unwrap_or_else(|| req.fqdn.clone());
let agent_port = req.agent_port.unwrap_or(12443);
let notes = req.notes.clone().unwrap_or_default();
// Insert host
let host_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO hosts (fqdn, ip_address, display_name, agent_port, notes)
VALUES ($1, $2::inet, $3, $4, $5)
RETURNING id
"#,
)
.bind(&req.fqdn)
.bind(&ip_address)
.bind(&display_name)
.bind(agent_port)
.bind(&notes)
.fetch_one(&state.db)
.await
.map_err(|e| {
let msg = if e.to_string().contains("unique") {
"Host with this FQDN and IP already exists".to_string()
} else {
"Database error".to_string()
};
tracing::error!(error = %e, "Failed to register host");
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
// Assign to groups if specified
if let Some(group_ids) = &req.group_ids {
for gid in group_ids {
let _ = sqlx::query(
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
)
.bind(host_id)
.bind(gid)
.execute(&state.db)
.await;
}
}
// Audit log
log_event(
&state.db,
AuditAction::HostRegistered,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&host_id.to_string()),
json!({ "fqdn": req.fqdn, "ip": ip_address }),
None,
None,
)
.await;
tracing::info!(host_id = %host_id, fqdn = %req.fqdn, "Host registered");
Ok(Json(json!({ "id": host_id, "message": "Host registered" })))
}
// ── GET /api/v1/hosts/:id ─────────────────────────────────────────────────────
async fn get_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
.await
.unwrap_or(false);
if !can_access {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
}
let host: Option<Value> = sqlx::query_scalar(
r#"
SELECT row_to_json(h) FROM (
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
os_family, os_name, arch, agent_version, health_status,
last_health_at, last_patch_at, agent_port, notes,
registered_at, updated_at
FROM hosts WHERE id = $1
) h
"#,
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to get host");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
host.map(Json).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
})
}
// ── DELETE /api/v1/hosts/:id ──────────────────────────────────────────────────
async fn remove_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Fetch FQDN for audit before deletion
let fqdn: Option<String> = sqlx::query_scalar("SELECT fqdn FROM hosts WHERE id = $1")
.bind(id)
.fetch_optional(&state.db)
.await
.unwrap_or(None);
let result = sqlx::query("DELETE FROM hosts WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to remove host");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
if result.rows_affected() == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
));
}
log_event(
&state.db,
AuditAction::HostRemoved,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&id.to_string()),
json!({ "fqdn": fqdn }),
None,
None,
)
.await;
tracing::info!(host_id = %id, "Host removed");
Ok(Json(json!({ "message": "Host removed" })))
}
// ── PUT /api/v1/hosts/:id ─────────────────────────────────────────────────────
async fn update_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<UpdateHostRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Update only fields that were provided; COALESCE preserves existing values.
let host = sqlx::query_scalar(
r#"
WITH updated AS (
UPDATE hosts SET
fqdn = COALESCE($1, fqdn),
ip_address = COALESCE($2::inet, ip_address),
display_name = COALESCE($3, display_name),
updated_at = NOW()
WHERE id = $4
RETURNING id
)
SELECT row_to_json(h) FROM (
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
os_family, os_name, arch, agent_version, health_status,
last_health_at, last_patch_at, agent_port, notes,
registered_at, updated_at
FROM hosts WHERE id = (SELECT id FROM updated)
) h
"#,
)
.bind(&req.fqdn)
.bind(&req.ip_address)
.bind(&req.display_name)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, host_id = %id, "Failed to update host");
let msg = if e.to_string().contains("unique") {
"A host with this FQDN and IP already exists".to_string()
} else {
"Database error".to_string()
};
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
host.map(Json).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
)
})
}
// ── GET /api/v1/hosts/:id/groups ──────────────────────────────────────────────
async fn list_host_groups(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
.await
.unwrap_or(false);
if !can_access {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
}
let groups: Vec<Group> = sqlx::query_as(
r#"SELECT g.id, g.name, g.description, g.created_at, g.updated_at
FROM groups g
JOIN host_groups hg ON hg.group_id = g.id
WHERE hg.host_id = $1
ORDER BY g.name"#,
)
.bind(id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to list host groups");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(Json(groups))
}
// ── POST /api/v1/hosts/:id/groups ─────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct AddToGroupRequest {
group_id: Uuid,
}
async fn add_host_to_group(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<AddToGroupRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query(
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
)
.bind(id)
.bind(req.group_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to add host to group");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&id.to_string()),
json!({ "group_id": req.group_id, "action": "added" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Host added to group" })))
}
// ── DELETE /api/v1/hosts/:id/groups/:group_id ─────────────────────────────────
async fn remove_host_from_group(
State(state): State<AppState>,
auth: AuthUser,
Path((id, group_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
sqlx::query("DELETE FROM host_groups WHERE host_id = $1 AND group_id = $2")
.bind(id)
.bind(group_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to remove host from group");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
log_event(
&state.db,
AuditAction::GroupMembershipChanged,
Some(auth.user_id),
Some(&auth.username),
Some("host"),
Some(&id.to_string()),
json!({ "group_id": group_id, "action": "removed" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Host removed from group" })))
}
// ── FQDN resolution ───────────────────────────────────────────────────────────
/// Resolve an FQDN (or IP) to its primary IP address.
/// If the input is already a valid IP, returns it as-is.
async fn resolve_fqdn(fqdn: &str) -> Result<String, String> {
use std::net::ToSocketAddrs;
// Try direct IP parse first
if fqdn.parse::<std::net::IpAddr>().is_ok() {
return Ok(fqdn.to_string());
}
// DNS resolution
let addr = format!("{fqdn}:0");
match tokio::task::spawn_blocking(move || addr.to_socket_addrs()).await {
Ok(Ok(mut addrs)) => addrs
.next()
.map(|a| a.ip().to_string())
.ok_or_else(|| format!("No addresses found for {fqdn}")),
_ => Err(format!("Failed to resolve FQDN: {fqdn}")),
}
}
// ── POST /api/v1/hosts/:id/refresh ───────────────────────────────────────────
/// Queue an on-demand health + patch refresh for a single host.
///
/// Sends a PostgreSQL NOTIFY on the `refresh_requested` channel; the
/// pm-worker refresh listener picks this up and polls the host immediately.
/// Requires Operator or Admin role (any authenticated user).
async fn refresh_host(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<(StatusCode, Json<Value>), (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
// Verify the host exists.
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
.bind(id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "refresh_host: db error checking host existence");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
if !exists {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
));
}
// NOTIFY the worker's refresh listener.
sqlx::query("SELECT pg_notify('refresh_requested', $1)")
.bind(id.to_string())
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "refresh_host: pg_notify failed");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to queue refresh" } })),
)
})?;
tracing::info!(%id, "On-demand refresh queued");
Ok((
StatusCode::ACCEPTED,
Json(json!({ "message": "Refresh queued" })),
))
}

677
crates/pm-web/src/routes/jobs.rs Executable file
View File

@ -0,0 +1,677 @@
//! Patch job management routes.
//!
//! POST /api/v1/jobs — create a new patch job (operator+)
//! GET /api/v1/jobs — list jobs with pagination (RBAC scoped)
//! GET /api/v1/jobs/{id} — get job detail + per-host status
//! POST /api/v1/jobs/{id}/cancel — cancel a queued/pending job (admin or creator)
//! POST /api/v1/jobs/{id}/rollback — create a rollback job (admin only)
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateJobRequest, PatchJobSummary},
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
// ── Router ────────────────────────────────────────────────────────────────────
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_jobs).post(create_job))
.route("/{id}", get(get_job))
.route("/{id}/cancel", post(cancel_job))
.route("/{id}/rollback", post(rollback_job))
}
// ── Query params ──────────────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
pub struct JobListQuery {
pub limit: Option<i64>,
pub offset: Option<i64>,
}
// ── Response types ────────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct JobListResponse {
jobs: Vec<PatchJobSummary>,
total: i64,
limit: i64,
offset: i64,
}
/// Per-host row included in `GET /api/v1/jobs/{id}` response.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
struct JobHostRow {
pub id: Uuid,
pub host_id: Uuid,
pub display_name: String,
pub status: String,
pub agent_job_id: Option<String>,
pub retry_count: i32,
pub output: String,
pub error_message: Option<String>,
pub last_error: Option<String>,
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
}
// ── Error helper ──────────────────────────────────────────────────────────────
#[inline]
fn err(
status: StatusCode,
code: &'static str,
message: impl Into<String>,
) -> (StatusCode, Json<Value>) {
(
status,
Json(json!({ "error": { "code": code, "message": message.into() } })),
)
}
// ── RBAC helper ───────────────────────────────────────────────────────────────
/// Returns `true` when the operator's groups contain at least one host that
/// belongs to the given job. Admins always pass this check at the call site.
async fn operator_can_access_job(
pool: &sqlx::PgPool,
user_id: Uuid,
job_id: Uuid,
) -> Result<bool, sqlx::Error> {
sqlx::query_scalar(
r#"
SELECT EXISTS (
SELECT 1
FROM patch_job_hosts pjh
JOIN host_groups hg ON hg.host_id = pjh.host_id
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE pjh.job_id = $1
AND ug.user_id = $2
)
"#,
)
.bind(job_id)
.bind(user_id)
.fetch_one(pool)
.await
}
// ── POST /api/v1/jobs ─────────────────────────────────────────────────────────
async fn create_job(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateJobRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
if req.host_ids.is_empty() {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"host_ids must not be empty",
));
}
// Encode package list as JSONB.
let patch_selection = serde_json::to_value(&req.packages).unwrap_or(json!([]));
let notes = req.notes.clone().unwrap_or_default();
// Insert the parent job row; the DB NOTIFY trigger fires automatically
// when immediate = TRUE (see migration 003_jobs_scheduling.sql).
let job_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO patch_jobs
(kind, status, created_by_user_id, maintenance_window_id,
immediate, patch_selection, notes)
VALUES
('patch_apply'::job_kind, 'queued'::job_status, $1, $2, $3, $4, $5)
RETURNING id
"#,
)
.bind(auth.user_id)
.bind(req.maintenance_window_id)
.bind(req.immediate)
.bind(&patch_selection)
.bind(&notes)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "create_job: insert patch_jobs failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Insert one patch_job_hosts row per requested host.
for host_id in &req.host_ids {
sqlx::query(
r#"
INSERT INTO patch_job_hosts (job_id, host_id, status)
VALUES ($1, $2, 'queued'::job_status)
ON CONFLICT (job_id, host_id) DO NOTHING
"#,
)
.bind(job_id)
.bind(host_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(
error = %e, %job_id, %host_id,
"create_job: insert patch_job_hosts failed"
);
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
}
log_event(
&state.db,
AuditAction::PatchJobCreated,
Some(auth.user_id),
Some(&auth.username),
Some("job"),
Some(&job_id.to_string()),
json!({
"kind": "patch_apply",
"immediate": req.immediate,
"host_count": req.host_ids.len(),
"packages": req.packages,
"notes": notes,
}),
None,
None,
)
.await;
tracing::info!(
job_id = %job_id,
host_count = req.host_ids.len(),
immediate = req.immediate,
user = %auth.username,
"Patch job created"
);
Ok(Json(json!({ "id": job_id, "message": "Job created" })))
}
// ── GET /api/v1/jobs ──────────────────────────────────────────────────────────
async fn list_jobs(
State(state): State<AppState>,
auth: AuthUser,
Query(q): Query<JobListQuery>,
) -> Result<Json<JobListResponse>, (StatusCode, Json<Value>)> {
let limit = q.limit.unwrap_or(50).min(200);
let offset = q.offset.unwrap_or(0);
let jobs: Vec<PatchJobSummary> = if auth.role.is_admin() {
// Admins see every job.
sqlx::query_as(
r#"
SELECT
pj.id,
pj.kind,
pj.status,
pj.immediate,
pj.notes,
pj.created_at,
pj.started_at,
pj.completed_at,
COUNT(pjh.id) AS host_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
FROM patch_jobs pj
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
GROUP BY pj.id
ORDER BY pj.created_at DESC
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await
} else {
// Operators: only jobs where at least one host is in their groups.
sqlx::query_as(
r#"
SELECT
pj.id,
pj.kind,
pj.status,
pj.immediate,
pj.notes,
pj.created_at,
pj.started_at,
pj.completed_at,
COUNT(pjh.id) AS host_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
FROM patch_jobs pj
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
WHERE EXISTS (
SELECT 1
FROM patch_job_hosts pjh2
JOIN host_groups hg ON hg.host_id = pjh2.host_id
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE pjh2.job_id = pj.id
AND ug.user_id = $3
)
GROUP BY pj.id
ORDER BY pj.created_at DESC
LIMIT $1 OFFSET $2
"#,
)
.bind(limit)
.bind(offset)
.bind(auth.user_id)
.fetch_all(&state.db)
.await
}
.map_err(|e| {
tracing::error!(error = %e, "list_jobs: query failed");
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
})?;
// Total count for pagination metadata.
let total: i64 = if auth.role.is_admin() {
sqlx::query_scalar("SELECT COUNT(*) FROM patch_jobs")
.fetch_one(&state.db)
.await
.unwrap_or(0)
} else {
sqlx::query_scalar(
r#"
SELECT COUNT(DISTINCT pj.id)
FROM patch_jobs pj
WHERE EXISTS (
SELECT 1
FROM patch_job_hosts pjh
JOIN host_groups hg ON hg.host_id = pjh.host_id
JOIN user_groups ug ON ug.group_id = hg.group_id
WHERE pjh.job_id = pj.id
AND ug.user_id = $1
)
"#,
)
.bind(auth.user_id)
.fetch_one(&state.db)
.await
.unwrap_or(0)
};
Ok(Json(JobListResponse {
jobs,
total,
limit,
offset,
}))
}
// ── GET /api/v1/jobs/:id ─────────────────────────────────────────────────────
async fn get_job(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// RBAC: operators may only view jobs touching their group's hosts.
if !auth.role.is_admin() {
let allowed = operator_can_access_job(&state.db, auth.user_id, id)
.await
.unwrap_or(false);
if !allowed {
return Err(err(StatusCode::FORBIDDEN, "forbidden", "Access denied"));
}
}
// Fetch the job header row as JSON.
let job: Option<Value> = sqlx::query_scalar(
r#"
SELECT row_to_json(j) FROM (
SELECT id, kind, status, created_by_user_id, parent_job_id,
maintenance_window_id, immediate, patch_selection, notes,
created_at, started_at, completed_at
FROM patch_jobs
WHERE id = $1
) j
"#,
)
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "get_job: failed to fetch job");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
let job = job.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
// Fetch per-host status rows joined to the host display name.
let hosts: Vec<JobHostRow> = sqlx::query_as(
r#"
SELECT
pjh.id,
pjh.host_id,
COALESCE(h.display_name, h.fqdn) AS display_name,
pjh.status::text AS status,
pjh.agent_job_id,
pjh.retry_count,
pjh.output,
pjh.error_message,
pjh.last_error,
pjh.started_at,
pjh.completed_at
FROM patch_job_hosts pjh
JOIN hosts h ON h.id = pjh.host_id
WHERE pjh.job_id = $1
ORDER BY h.display_name
"#,
)
.bind(id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "get_job: failed to fetch host rows");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
Ok(Json(json!({ "job": job, "hosts": hosts })))
}
// ── POST /api/v1/jobs/:id/cancel ─────────────────────────────────────────────
async fn cancel_job(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Fetch the job to verify it exists and check ownership.
let row: Option<(String, Option<Uuid>)> =
sqlx::query_as("SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1")
.bind(id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "cancel_job: db fetch failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
let (status_str, creator_id) =
row.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
// Only admin or the job creator may cancel.
if !auth.role.can_write() {
let is_creator = creator_id == Some(auth.user_id);
if !is_creator {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
}
// Only queued or pending jobs can be cancelled.
if status_str != "queued" && status_str != "pending" {
return Err(err(
StatusCode::CONFLICT,
"invalid_state",
format!(
"Cannot cancel a job in '{}' state; only queued or pending jobs may be cancelled",
status_str
),
));
}
// Cancel the parent job.
sqlx::query("UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Cancel all queued/pending host rows for this job.
sqlx::query(
r#"
UPDATE patch_job_hosts
SET status = 'cancelled'::job_status
WHERE job_id = $1
AND status IN ('queued'::job_status, 'pending'::job_status)
"#,
)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "cancel_job: update patch_job_hosts failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Fire job-level pg_notify so the frontend can update the job row.
let notify_payload = json!({
"event_type": "job",
"job_id": id.to_string(),
"host_id": "",
"status": "cancelled",
"succeeded_count": 0,
"failed_count": 0,
"host_count": 0,
});
if let Ok(payload_str) = serde_json::to_string(&notify_payload) {
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
.bind(&payload_str)
.execute(&state.db)
.await
{
tracing::error!(error = %e, %id, "cancel_job: job-level pg_notify failed");
} else {
tracing::info!(%id, "cancel_job: job-level pg_notify sent");
}
}
log_event(
&state.db,
AuditAction::PatchJobCancelled,
Some(auth.user_id),
Some(&auth.username),
Some("job"),
Some(&id.to_string()),
json!({ "previous_status": status_str }),
None,
None,
)
.await;
tracing::info!(job_id = %id, user = %auth.username, "Patch job cancelled");
Ok(Json(json!({ "message": "Job cancelled" })))
}
// ── POST /api/v1/jobs/:id/rollback ────────────────────────────────────────────
async fn rollback_job(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Admin-only operation.
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
// Verify the original job exists.
let original_exists: bool =
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM patch_jobs WHERE id = $1)")
.bind(id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "rollback_job: existence check failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if !original_exists {
return Err(err(StatusCode::NOT_FOUND, "not_found", "Job not found"));
}
// Gather the host IDs from the original job.
let host_ids: Vec<Uuid> =
sqlx::query_scalar("SELECT host_id FROM patch_job_hosts WHERE job_id = $1")
.bind(id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %id, "rollback_job: host fetch failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if host_ids.is_empty() {
return Err(err(
StatusCode::UNPROCESSABLE_ENTITY,
"no_hosts",
"Original job has no host entries to roll back",
));
}
// Create the rollback job row (immediate = true so the worker picks it up
// right away and the NOTIFY trigger fires).
let rollback_job_id: Uuid = sqlx::query_scalar(
r#"
INSERT INTO patch_jobs
(kind, status, created_by_user_id, parent_job_id, immediate,
patch_selection, notes)
VALUES
('rollback'::job_kind, 'queued'::job_status, $1, $2, TRUE,
'[]'::jsonb, $3)
RETURNING id
"#,
)
.bind(auth.user_id)
.bind(id)
.bind(format!("Rollback of job {}", id))
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, parent_job_id = %id, "rollback_job: insert failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
// Replicate host list into the rollback job.
for host_id in &host_ids {
sqlx::query(
r#"
INSERT INTO patch_job_hosts (job_id, host_id, status)
VALUES ($1, $2, 'queued'::job_status)
ON CONFLICT (job_id, host_id) DO NOTHING
"#,
)
.bind(rollback_job_id)
.bind(host_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(
error = %e, %rollback_job_id, %host_id,
"rollback_job: insert patch_job_hosts failed"
);
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
}
log_event(
&state.db,
AuditAction::PatchJobRollback,
Some(auth.user_id),
Some(&auth.username),
Some("job"),
Some(&rollback_job_id.to_string()),
json!({
"original_job_id": id,
"rollback_job_id": rollback_job_id,
"host_count": host_ids.len(),
}),
None,
None,
)
.await;
tracing::info!(
rollback_job_id = %rollback_job_id,
original_job_id = %id,
user = %auth.username,
"Rollback job created"
);
Ok(Json(json!({
"id": rollback_job_id,
"parent_job_id": id,
"message": "Rollback job created"
})))
}

View File

@ -0,0 +1,452 @@
//! Maintenance window management routes.
//!
//! GET /api/v1/hosts/{id}/maintenance-windows — list windows for host
//! GET /api/v1/maintenance-windows — list ALL windows (bulk)
//! POST /api/v1/hosts/{id}/maintenance-windows — create window for host
//! PUT /api/v1/hosts/{id}/maintenance-windows/{win_id} — update window
//! DELETE /api/v1/hosts/{id}/maintenance-windows/{win_id} — delete window
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{get, put},
Router,
};
use pm_auth::rbac::AuthUser;
use pm_core::{
audit::{log_event, AuditAction},
models::{CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest},
};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
// ── Router ────────────────────────────────────────────────────────────────────
/// Mount as a nested router under `/hosts/{host_id}/maintenance-windows`.
/// Axum will merge the `{host_id}` path segment from the parent nest.
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_windows).post(create_window))
.route("/{win_id}", put(update_window).delete(delete_window))
}
/// Top-level router for `/api/v1/maintenance-windows` — bulk list-all endpoint.
pub fn all_windows_router() -> Router<AppState> {
Router::new().route("/", get(list_all_windows))
}
// ── GET /api/v1/maintenance-windows ──────────────────────────────────────────
/// Bulk endpoint: return every maintenance window across all hosts.
/// Eliminates N+1 queries from the frontend (one request instead of one per host).
async fn list_all_windows(
State(state): State<AppState>,
_auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
r#"
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
FROM maintenance_windows
ORDER BY host_id, created_at ASC
"#,
)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "list_all_windows: query failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
Ok(Json(json!({ "windows": windows })))
}
// ── Error helper ──────────────────────────────────────────────────────────────
#[inline]
fn err(
status: StatusCode,
code: &'static str,
message: impl Into<String>,
) -> (StatusCode, Json<Value>) {
(
status,
Json(json!({ "error": { "code": code, "message": message.into() } })),
)
}
// ── GET /api/v1/hosts/:host_id/maintenance-windows ────────────────────────────
async fn list_windows(
State(state): State<AppState>,
_auth: AuthUser,
Path(host_id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Verify host exists.
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "list_windows: host existence check failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if !host_exists {
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
}
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
r#"
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
FROM maintenance_windows
WHERE host_id = $1
ORDER BY created_at ASC
"#,
)
.bind(host_id)
.fetch_all(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "list_windows: query failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
Ok(Json(json!({ "windows": windows })))
}
// ── POST /api/v1/hosts/:host_id/maintenance-windows ───────────────────────────
async fn create_window(
State(state): State<AppState>,
auth: AuthUser,
Path(host_id): Path<Uuid>,
Json(req): Json<CreateMaintenanceWindowRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
// Validate: weekly requires recurrence_day 0-6
if req.recurrence == pm_core::models::WindowRecurrence::Weekly {
match req.recurrence_day {
Some(d) if (0..=6).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Weekly recurrence requires recurrence_day 0-6 (0=Sunday)",
));
},
}
}
// Validate: monthly requires recurrence_day 1-31
if req.recurrence == pm_core::models::WindowRecurrence::Monthly {
match req.recurrence_day {
Some(d) if (1..=31).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Monthly recurrence requires recurrence_day 1-31",
));
},
}
}
// Verify host exists.
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
.bind(host_id)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "create_window: host existence check failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if !host_exists {
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
}
let duration = req.duration_minutes.unwrap_or(60);
let enabled = req.enabled.unwrap_or(true);
let auto_apply = req.auto_apply.unwrap_or(true);
let window: MaintenanceWindow = sqlx::query_as(
r#"
INSERT INTO maintenance_windows
(host_id, label, recurrence, start_at, duration_minutes, recurrence_day, enabled, auto_apply)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
"#,
)
.bind(host_id)
.bind(&req.label)
.bind(&req.recurrence)
.bind(req.start_at)
.bind(duration)
.bind(req.recurrence_day)
.bind(enabled)
.bind(auto_apply)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %host_id, "create_window: insert failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
log_event(
&state.db,
AuditAction::MaintenanceWindowCreated,
Some(auth.user_id),
Some(&auth.username),
Some("maintenance_window"),
Some(&window.id.to_string()),
json!({
"host_id": host_id,
"label": window.label,
"recurrence": window.recurrence.to_string(),
}),
None,
None,
)
.await;
tracing::info!(
window_id = %window.id,
%host_id,
recurrence = %window.recurrence,
user = %auth.username,
"Maintenance window created"
);
Ok(Json(json!(window)))
}
// ── PUT /api/v1/hosts/:host_id/maintenance-windows/:win_id ───────────────────
async fn update_window(
State(state): State<AppState>,
auth: AuthUser,
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
Json(req): Json<UpdateMaintenanceWindowRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
// Fetch existing record (verify ownership and existence).
let existing: Option<MaintenanceWindow> = sqlx::query_as(
r#"
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
FROM maintenance_windows
WHERE id = $1 AND host_id = $2
"#,
)
.bind(win_id)
.bind(host_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %win_id, "update_window: fetch failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
let existing = existing.ok_or_else(|| {
err(
StatusCode::NOT_FOUND,
"not_found",
"Maintenance window not found",
)
})?;
// Apply partial updates using existing values as defaults.
let new_label = req.label.unwrap_or(existing.label);
let new_recurrence = req.recurrence.unwrap_or(existing.recurrence);
let new_start_at = req.start_at.unwrap_or(existing.start_at);
let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes);
let new_rec_day = req.recurrence_day.or(existing.recurrence_day);
let new_enabled = req.enabled.unwrap_or(existing.enabled);
let new_auto_apply = req.auto_apply.unwrap_or(existing.auto_apply);
// Validate recurrence_day for the final recurrence type.
if new_recurrence == pm_core::models::WindowRecurrence::Weekly {
match new_rec_day {
Some(d) if (0..=6).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Weekly recurrence requires recurrence_day 0-6",
));
},
}
}
if new_recurrence == pm_core::models::WindowRecurrence::Monthly {
match new_rec_day {
Some(d) if (1..=31).contains(&d) => {},
_ => {
return Err(err(
StatusCode::BAD_REQUEST,
"bad_request",
"Monthly recurrence requires recurrence_day 1-31",
));
},
}
}
let updated: MaintenanceWindow = sqlx::query_as(
r#"
UPDATE maintenance_windows
SET label = $3,
recurrence = $4,
start_at = $5,
duration_minutes = $6,
recurrence_day = $7,
enabled = $8,
auto_apply = $9,
updated_at = NOW()
WHERE id = $1 AND host_id = $2
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
recurrence_day, enabled, auto_apply, created_at, updated_at
"#,
)
.bind(win_id)
.bind(host_id)
.bind(&new_label)
.bind(&new_recurrence)
.bind(new_start_at)
.bind(new_duration)
.bind(new_rec_day)
.bind(new_enabled)
.bind(new_auto_apply)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %win_id, "update_window: update failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
log_event(
&state.db,
AuditAction::MaintenanceWindowUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("maintenance_window"),
Some(&win_id.to_string()),
json!({ "host_id": host_id }),
None,
None,
)
.await;
tracing::info!(
window_id = %win_id,
%host_id,
user = %auth.username,
"Maintenance window updated"
);
Ok(Json(json!(updated)))
}
// ── DELETE /api/v1/hosts/:host_id/maintenance-windows/:win_id ────────────────
async fn delete_window(
State(state): State<AppState>,
auth: AuthUser,
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err(err(
StatusCode::FORBIDDEN,
"forbidden",
"Write access required",
));
}
let result = sqlx::query("DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2")
.bind(win_id)
.bind(host_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, %win_id, "delete_window: delete failed");
err(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_error",
"Database error",
)
})?;
if result.rows_affected() == 0 {
return Err(err(
StatusCode::NOT_FOUND,
"not_found",
"Maintenance window not found",
));
}
log_event(
&state.db,
AuditAction::MaintenanceWindowDeleted,
Some(auth.user_id),
Some(&auth.username),
Some("maintenance_window"),
Some(&win_id.to_string()),
json!({ "host_id": host_id }),
None,
None,
)
.await;
tracing::info!(
window_id = %win_id,
%host_id,
user = %auth.username,
"Maintenance window deleted"
);
Ok(Json(json!({ "message": "Maintenance window deleted" })))
}

17
crates/pm-web/src/routes/mod.rs Executable file
View File

@ -0,0 +1,17 @@
//! Route modules for the pm-web API.
pub mod auth;
pub mod ca;
pub mod discovery;
pub mod enrollment;
pub mod groups;
pub mod health_checks;
pub mod hosts;
pub mod jobs;
pub mod maintenance_windows;
pub mod settings;
pub mod sso;
pub mod status;
pub mod users;
pub mod ws;
pub mod reports;

View File

@ -0,0 +1,163 @@
//! Report generation endpoints.
//!
//! GET /api/v1/reports/compliance?format=csv|pdf&from=...&to=...&group_id=...
//! GET /api/v1/reports/patch-history?format=csv|pdf&from=...&to=...
//! GET /api/v1/reports/vulnerability?format=csv|pdf&from=...&to=...
//! GET /api/v1/reports/audit?format=csv|pdf&from=...&to=...
use axum::{
body::Bytes,
extract::{Query, State},
http::{header, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
routing::get,
Router,
};
use pm_reports::{ReportParams, ReportType};
use crate::AppState;
#[derive(serde::Deserialize)]
struct ReportQuery {
/// "csv" or "pdf" (defaults to "csv")
format: Option<String>,
from: Option<chrono::DateTime<chrono::Utc>>,
to: Option<chrono::DateTime<chrono::Utc>>,
group_id: Option<uuid::Uuid>,
}
pub fn router() -> Router<AppState> {
Router::new()
.route("/compliance", get(compliance_report))
.route("/patch-history", get(patch_history_report))
.route("/vulnerability", get(vulnerability_report))
.route("/audit", get(audit_report))
}
// ---------------------------------------------------------------------------
// Internal helper
// ---------------------------------------------------------------------------
async fn run_report(
db: sqlx::PgPool,
params: ReportParams,
use_pdf: bool,
csv_name: &'static str,
pdf_name: &'static str,
) -> Response {
let (ct, disposition, result) = if use_pdf {
let disp = format!("attachment; filename=\"{}\"", pdf_name);
let data = pm_reports::generate_pdf(&db, &params).await;
("application/pdf", disp, data)
} else {
let disp = format!("attachment; filename=\"{}\"", csv_name);
let data = pm_reports::generate_csv(&db, &params).await;
("text/csv; charset=utf-8", disp, data)
};
match result {
Ok(bytes) => {
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static(ct));
headers.insert(
header::CONTENT_DISPOSITION,
HeaderValue::from_str(&disposition)
.unwrap_or_else(|_| HeaderValue::from_static("attachment")),
);
(headers, Bytes::from(bytes)).into_response()
},
Err(e) => {
tracing::error!(error = %e, "report generation failed");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Report error: {}", e),
)
.into_response()
},
}
}
// ---------------------------------------------------------------------------
// Handlers
// ---------------------------------------------------------------------------
async fn compliance_report(
State(state): State<AppState>,
Query(q): Query<ReportQuery>,
) -> Response {
let params = ReportParams {
report_type: ReportType::Compliance,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"compliance-report.csv",
"compliance-report.pdf",
)
.await
}
async fn patch_history_report(
State(state): State<AppState>,
Query(q): Query<ReportQuery>,
) -> Response {
let params = ReportParams {
report_type: ReportType::PatchHistory,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"patch-history-report.csv",
"patch-history-report.pdf",
)
.await
}
async fn vulnerability_report(
State(state): State<AppState>,
Query(q): Query<ReportQuery>,
) -> Response {
let params = ReportParams {
report_type: ReportType::Vulnerability,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"vulnerability-report.csv",
"vulnerability-report.pdf",
)
.await
}
async fn audit_report(State(state): State<AppState>, Query(q): Query<ReportQuery>) -> Response {
let params = ReportParams {
report_type: ReportType::Audit,
from: q.from,
to: q.to,
group_id: q.group_id,
};
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
run_report(
state.db,
params,
use_pdf,
"audit-report.csv",
"audit-report.pdf",
)
.await
}

View File

@ -0,0 +1,977 @@
//! Settings management routes.
//!
//! GET /api/v1/settings — get all settings (admin only)
//! PUT /api/v1/settings — update settings (admin only)
//! POST /api/v1/settings/sso/discover — discover OIDC endpoints (admin only)
//! POST /api/v1/settings/sso/test — test OIDC provider connectivity (admin only)
//! POST /api/v1/settings/azure-sso/test — backward-compat alias for SSO test (admin only)
//! POST /api/v1/settings/smtp/test — send test email (admin only)
//! GET /api/v1/settings/ip-whitelist — get IP whitelist (admin only)
//! PUT /api/v1/settings/ip-whitelist — update IP whitelist (admin only)
//! POST /api/v1/settings/audit-integrity — verify audit log integrity (admin only)
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use lettre::{
message::{header::ContentType, Mailbox},
transport::smtp::authentication::Credentials,
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
};
use pm_auth::rbac::AuthUser;
use pm_core::audit::{log_event, verify_integrity, AuditAction};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::AppState;
// ============================================================
// Data structures
// ============================================================
#[derive(Debug, Serialize)]
pub struct SettingsResponse {
pub oidc: OidcConfigResponse,
pub smtp: SmtpConfig,
pub polling: PollingConfig,
pub ip_whitelist: Vec<String>,
pub web_tls_strategy: String,
pub notification: NotificationConfig,
pub sso_callback_url: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OidcConfigResponse {
pub enabled: bool,
pub provider_type: String, // "keycloak", "azure", "custom"
pub display_name: String,
pub discovery_url: String,
pub client_id: String,
pub client_secret: String, // Always masked in responses
pub redirect_uri: String,
pub scopes: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SmtpConfig {
pub enabled: bool,
pub host: String,
pub port: u16,
pub username: String,
pub from: String,
pub tls_mode: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PollingConfig {
pub health_poll_interval_secs: u64,
pub patch_poll_interval_secs: u64,
}
#[derive(Debug, Deserialize)]
pub struct UpdateSettingsRequest {
pub oidc: Option<OidcConfigUpdate>,
pub smtp: Option<SmtpConfigUpdate>,
pub polling: Option<PollingConfigUpdate>,
pub ip_whitelist: Option<Vec<String>>,
pub web_tls_strategy: Option<String>,
pub notification: Option<NotificationConfigUpdate>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct NotificationConfig {
pub email_enabled: bool,
pub email_from: String,
pub recipients: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct NotificationConfigUpdate {
pub email_enabled: Option<bool>,
pub email_from: Option<String>,
pub recipients: Option<Vec<String>>,
}
#[derive(Debug, Deserialize)]
pub struct OidcConfigUpdate {
pub enabled: Option<bool>,
pub provider_type: Option<String>,
pub display_name: Option<String>,
pub discovery_url: Option<String>,
pub client_id: Option<String>,
pub client_secret: Option<String>,
pub redirect_uri: Option<String>,
pub scopes: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct OidcDiscoveryRequest {
pub discovery_url: String,
}
#[derive(Debug, Serialize)]
#[allow(dead_code)]
pub struct OidcDiscoveryResult {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub jwks_uri: String,
pub userinfo_endpoint: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct SmtpConfigUpdate {
pub enabled: Option<bool>,
pub host: Option<String>,
pub port: Option<u16>,
pub username: Option<String>,
pub password: Option<String>,
pub from: Option<String>,
pub tls_mode: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct PollingConfigUpdate {
pub health_poll_interval_secs: Option<u64>,
pub patch_poll_interval_secs: Option<u64>,
}
#[derive(Debug, Deserialize)]
pub struct IpWhitelistUpdate {
pub entries: Vec<String>,
}
// ============================================================
// Router
// ============================================================
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(get_settings).put(update_settings))
.route("/sso/discover", post(discover_oidc))
.route("/sso/test", post(test_oidc))
.route("/azure-sso/test", post(test_azure_sso_compat))
.route("/smtp/test", post(test_smtp))
.route(
"/ip-whitelist",
get(get_ip_whitelist).put(update_ip_whitelist),
)
.route("/audit-integrity", post(audit_integrity))
}
// ============================================================
// Helpers
// ============================================================
const MASKED: &str = "********";
fn write_access_required(auth: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
if !auth.role.can_write() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
));
}
Ok(())
}
async fn load_system_config(
pool: &sqlx::PgPool,
) -> Result<HashMap<String, String>, (StatusCode, Json<Value>)> {
let rows: Vec<(String, String)> = sqlx::query_as("SELECT key, value FROM system_config")
.fetch_all(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load system_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(rows.into_iter().collect())
}
fn build_settings_response(
cfg: &HashMap<String, String>,
oidc: OidcConfigResponse,
) -> SettingsResponse {
let get = |key: &str| -> String { cfg.get(key).cloned().unwrap_or_default() };
let recipients: Vec<String> =
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
SettingsResponse {
oidc,
smtp: SmtpConfig {
enabled: get("smtp_enabled") == "true",
host: get("smtp_host"),
port: get("smtp_port").parse().unwrap_or(587),
username: get("smtp_username"),
from: get("smtp_from"),
tls_mode: get("smtp_tls_mode"),
},
polling: PollingConfig {
health_poll_interval_secs: get("health_poll_interval_secs").parse().unwrap_or(300),
patch_poll_interval_secs: get("patch_poll_interval_secs").parse().unwrap_or(1800),
},
ip_whitelist: serde_json::from_str(&get("ip_whitelist")).unwrap_or_default(),
web_tls_strategy: get("web_tls_strategy"),
notification: NotificationConfig {
email_enabled: get("notification_email_enabled") == "true",
email_from: get("notification_email_from"),
recipients,
},
sso_callback_url: get("sso_callback_url"),
}
}
async fn update_config_key(
pool: &sqlx::PgPool,
key: &str,
value: &str,
) -> Result<(), (StatusCode, Json<Value>)> {
sqlx::query("UPDATE system_config SET value = $1, updated_at = NOW() WHERE key = $2")
.bind(value)
.bind(key)
.execute(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, key, "Failed to update system_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(())
}
async fn fetch_oidc_config(
pool: &sqlx::PgPool,
) -> Result<OidcConfigResponse, (StatusCode, Json<Value>)> {
let row: Option<(bool, String, String, String, String, String, String, String)> = sqlx::query_as(
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
)
.fetch_optional(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(match row {
Some((
enabled,
provider_type,
display_name,
discovery_url,
client_id,
client_secret,
redirect_uri,
scopes,
)) => OidcConfigResponse {
enabled,
provider_type,
display_name,
discovery_url,
client_id,
client_secret: if client_secret.is_empty() {
String::new()
} else {
MASKED.to_string()
},
redirect_uri,
scopes,
},
None => OidcConfigResponse {
enabled: false,
provider_type: "azure".to_string(),
display_name: "Azure AD".to_string(),
discovery_url: String::new(),
client_id: String::new(),
client_secret: String::new(),
redirect_uri: String::new(),
scopes: "openid profile email".to_string(),
},
})
}
// ============================================================
// GET /api/v1/settings
// ============================================================
async fn get_settings(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let cfg = load_system_config(&state.db).await?;
// Inject read-only config values from TOML file (not stored in DB)
let mut cfg = cfg;
cfg.insert(
"sso_callback_url".to_string(),
state.config.security.sso_callback_url.clone(),
);
let oidc = fetch_oidc_config(&state.db).await?;
Ok(Json(build_settings_response(&cfg, oidc)))
}
// ============================================================
// PUT /api/v1/settings
// ============================================================
async fn update_settings(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<UpdateSettingsRequest>,
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
// Update OIDC config
if let Some(oidc) = req.oidc {
let update_secret = oidc
.client_secret
.as_ref()
.is_some_and(|s| s != MASKED && !s.is_empty());
let result = if update_secret {
sqlx::query(
"UPDATE oidc_config SET \
enabled = COALESCE($1, enabled), \
provider_type = COALESCE($2, provider_type), \
display_name = COALESCE($3, display_name), \
discovery_url = COALESCE($4, discovery_url), \
client_id = COALESCE($5, client_id), \
client_secret = $6, \
redirect_uri = COALESCE($7, redirect_uri), \
scopes = COALESCE($8, scopes), \
updated_at = NOW() \
WHERE id = 1",
)
.bind(oidc.enabled)
.bind(&oidc.provider_type)
.bind(&oidc.display_name)
.bind(&oidc.discovery_url)
.bind(&oidc.client_id)
.bind(oidc.client_secret.as_deref().unwrap_or(""))
.bind(&oidc.redirect_uri)
.bind(&oidc.scopes)
.execute(&state.db)
.await
} else {
sqlx::query(
"UPDATE oidc_config SET \
enabled = COALESCE($1, enabled), \
provider_type = COALESCE($2, provider_type), \
display_name = COALESCE($3, display_name), \
discovery_url = COALESCE($4, discovery_url), \
client_id = COALESCE($5, client_id), \
redirect_uri = COALESCE($6, redirect_uri), \
scopes = COALESCE($7, scopes), \
updated_at = NOW() \
WHERE id = 1",
)
.bind(oidc.enabled)
.bind(&oidc.provider_type)
.bind(&oidc.display_name)
.bind(&oidc.discovery_url)
.bind(&oidc.client_id)
.bind(&oidc.redirect_uri)
.bind(&oidc.scopes)
.execute(&state.db)
.await
};
result.map_err(|e| {
tracing::error!(error = %e, "Failed to update oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": format!("Failed to update OIDC config: {}", e) } })),
)
})?;
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("oidc"),
Some("1"),
json!({ "section": "oidc" }),
None,
None,
)
.await;
}
// Update SMTP config
if let Some(smtp) = &req.smtp {
if let Some(v) = smtp.enabled {
update_config_key(&state.db, "smtp_enabled", &v.to_string()).await?;
}
if let Some(ref v) = smtp.host {
update_config_key(&state.db, "smtp_host", v).await?;
}
if let Some(v) = smtp.port {
update_config_key(&state.db, "smtp_port", &v.to_string()).await?;
}
if let Some(ref v) = smtp.username {
update_config_key(&state.db, "smtp_username", v).await?;
}
if let Some(ref v) = smtp.password {
if v != MASKED {
update_config_key(&state.db, "smtp_password", v).await?;
}
}
if let Some(ref v) = smtp.from {
update_config_key(&state.db, "smtp_from", v).await?;
}
if let Some(ref v) = smtp.tls_mode {
update_config_key(&state.db, "smtp_tls_mode", v).await?;
}
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("smtp"),
Some("system_config"),
json!({ "section": "smtp" }),
None,
None,
)
.await;
}
// Update polling config
if let Some(polling) = &req.polling {
if let Some(v) = polling.health_poll_interval_secs {
update_config_key(&state.db, "health_poll_interval_secs", &v.to_string()).await?;
}
if let Some(v) = polling.patch_poll_interval_secs {
update_config_key(&state.db, "patch_poll_interval_secs", &v.to_string()).await?;
}
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("polling"),
Some("system_config"),
json!({ "section": "polling" }),
None,
None,
)
.await;
}
// Update IP whitelist
if let Some(ref entries) = req.ip_whitelist {
let json_str = serde_json::to_string(entries).unwrap_or_else(|_| "[]".to_string());
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
// Update in-memory AuthConfig for immediate enforcement
state.auth_config.update_ip_whitelist(entries.clone());
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("ip_whitelist"),
Some("system_config"),
json!({ "entries": entries }),
None,
None,
)
.await;
}
// Update web TLS strategy
if let Some(ref v) = req.web_tls_strategy {
update_config_key(&state.db, "web_tls_strategy", v).await?;
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("web_tls_strategy"),
Some("system_config"),
json!({ "web_tls_strategy": v }),
None,
None,
)
.await;
}
// Update notification config
if let Some(notif) = &req.notification {
if let Some(v) = notif.email_enabled {
update_config_key(&state.db, "notification_email_enabled", &v.to_string()).await?;
}
if let Some(ref v) = notif.email_from {
update_config_key(&state.db, "notification_email_from", v).await?;
}
if let Some(ref v) = notif.recipients {
let json_str = serde_json::to_string(v).unwrap_or_else(|_| "[]".to_string());
update_config_key(&state.db, "notification_email_recipients", &json_str).await?;
}
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("notification"),
Some("system_config"),
json!({ "section": "notification" }),
None,
None,
)
.await;
}
// Return updated settings
let cfg = load_system_config(&state.db).await?;
// Inject read-only config values from TOML file (not stored in DB)
let mut cfg = cfg;
cfg.insert(
"sso_callback_url".to_string(),
state.config.security.sso_callback_url.clone(),
);
let oidc = fetch_oidc_config(&state.db).await?;
Ok(Json(build_settings_response(&cfg, oidc)))
}
// ============================================================
// POST /api/v1/settings/sso/discover
// ============================================================
async fn discover_oidc(
State(_state): State<AppState>,
auth: AuthUser,
Json(req): Json<OidcDiscoveryRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
if req.discovery_url.is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "bad_request", "message": "discovery_url is required" } }),
),
));
}
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| {
tracing::error!(error = %e, "Failed to build HTTP client");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
)
})?;
match client.get(&req.discovery_url).send().await {
Ok(resp) if resp.status().is_success() => {
let body: Value = resp.json().await.unwrap_or(json!({}));
Ok(Json(json!({
"success": true,
"issuer": body.get("issuer").and_then(|v| v.as_str()).unwrap_or(""),
"authorization_endpoint": body.get("authorization_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
"token_endpoint": body.get("token_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
"jwks_uri": body.get("jwks_uri").and_then(|v| v.as_str()).unwrap_or(""),
"userinfo_endpoint": body.get("userinfo_endpoint").and_then(|v| v.as_str()),
})))
},
Ok(resp) => Err((
StatusCode::BAD_GATEWAY,
Json(
json!({ "error": { "code": "discovery_failed", "message": format!("Discovery endpoint returned HTTP {}", resp.status()) } }),
),
)),
Err(e) => Err((
StatusCode::BAD_GATEWAY,
Json(
json!({ "error": { "code": "discovery_failed", "message": format!("Failed to reach discovery endpoint: {}", e) } }),
),
)),
}
}
// ============================================================
// POST /api/v1/settings/sso/test
// ============================================================
async fn test_oidc(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let row: Option<(bool, String, String)> = sqlx::query_as(
"SELECT enabled, provider_type, discovery_url FROM oidc_config WHERE id = 1",
)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let (enabled, provider_type, discovery_url) = match row {
Some(r) => r,
None => {
return Ok(Json(json!({
"success": false,
"message": "OIDC is not configured"
})));
},
};
if !enabled {
return Ok(Json(json!({
"success": false,
"message": "OIDC is not enabled"
})));
}
if discovery_url.is_empty() {
return Ok(Json(json!({
"success": false,
"message": "OIDC discovery URL is not set"
})));
}
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| {
tracing::error!(error = %e, "Failed to build HTTP client");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
)
})?;
match client.get(&discovery_url).send().await {
Ok(resp) if resp.status().is_success() => {
let body: Value = resp.json().await.unwrap_or(json!({}));
let issuer = body.get("issuer").and_then(|v| v.as_str()).unwrap_or("");
let provider_label = match provider_type.as_str() {
"keycloak" => "Keycloak",
"azure" => "Azure AD",
_ => "OIDC",
};
Ok(Json(json!({
"success": true,
"message": format!("{} provider verified successfully", provider_label),
"issuer": issuer,
"provider_type": provider_type,
})))
},
Ok(resp) => Ok(Json(json!({
"success": false,
"message": format!("Failed to reach OIDC provider: HTTP {}", resp.status())
}))),
Err(e) => Ok(Json(json!({
"success": false,
"message": format!("Failed to reach OIDC provider: {}", e)
}))),
}
}
// ============================================================
// POST /api/v1/settings/azure-sso/test (backward-compatible alias)
// ============================================================
async fn test_azure_sso_compat(
state: State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
test_oidc(state, auth).await
}
// ============================================================
// POST /api/v1/settings/smtp/test
// ============================================================
async fn test_smtp(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let cfg = load_system_config(&state.db).await?;
let smtp_enabled = cfg.get("smtp_enabled").map(|v| v.as_str()) == Some("true");
if !smtp_enabled {
return Ok(Json(json!({
"success": false,
"message": "SMTP is not enabled"
})));
}
let host = cfg.get("smtp_host").cloned().unwrap_or_default();
let port: u16 = cfg
.get("smtp_port")
.and_then(|v| v.parse().ok())
.unwrap_or(587);
let username = cfg.get("smtp_username").cloned().unwrap_or_default();
let password = cfg.get("smtp_password").cloned().unwrap_or_default();
let from_addr = cfg.get("smtp_from").cloned().unwrap_or_default();
let tls_mode = cfg
.get("smtp_tls_mode")
.cloned()
.unwrap_or_else(|| "starttls".to_string());
let recipients_str = cfg
.get("notification_email_recipients")
.cloned()
.unwrap_or_default();
let recipients: Vec<String> = serde_json::from_str(&recipients_str).unwrap_or_default();
if host.is_empty() || from_addr.is_empty() {
return Ok(Json(json!({
"success": false,
"message": "SMTP host or from address is not configured"
})));
}
let result = send_smtp_test(
&host,
port,
&username,
&password,
&from_addr,
&tls_mode,
&recipients,
)
.await;
match result {
Ok(()) => {
let recipient_info = if recipients.is_empty() {
String::new()
} else {
format!(" and {} recipient(s)", recipients.len())
};
Ok(Json(json!({
"success": true,
"message": format!("Test email sent successfully to from address{}", recipient_info)
})))
},
Err(e) => Ok(Json(json!({
"success": false,
"message": format!("Failed to send test email: {}", e)
}))),
}
}
async fn send_smtp_test(
host: &str,
port: u16,
username: &str,
password: &str,
from_addr: &str,
tls_mode: &str,
recipients: &[String],
) -> Result<(), String> {
let from_mailbox: Mailbox = from_addr
.parse()
.map_err(|e| format!("Invalid from address: {}", e))?;
let mut builder = Message::builder()
.from(from_mailbox.clone())
.to(from_mailbox);
for recipient in recipients {
if let Ok(addr) = recipient.parse() {
builder = builder.bcc(addr);
}
}
let body = if recipients.is_empty() {
"This is a test email from Linux Patch Manager.".to_string()
} else {
format!(
"This is a test email from Linux Patch Manager.\n\nSent to: {}",
recipients.join(", ")
)
};
let email = builder
.subject("Linux Patch Manager — SMTP Test")
.header(ContentType::TEXT_PLAIN)
.body(body)
.map_err(|e| format!("Failed to build email: {}", e))?;
let result = match tls_mode {
"tls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(host)
.map_err(|e| format!("TLS relay error: {}", e))?;
builder = builder.port(port);
if !username.is_empty() {
builder = builder
.credentials(Credentials::new(username.to_string(), password.to_string()));
}
let transport = builder.build();
transport.send(email).await
},
"starttls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(host)
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
builder = builder.port(port);
if !username.is_empty() {
builder = builder
.credentials(Credentials::new(username.to_string(), password.to_string()));
}
let transport = builder.build();
transport.send(email).await
},
_ => {
// "none" — plaintext / no TLS
let mut builder =
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(host).port(port);
if !username.is_empty() {
builder = builder
.credentials(Credentials::new(username.to_string(), password.to_string()));
}
let transport = builder.build();
transport.send(email).await
},
};
result
.map(|_| ())
.map_err(|e| format!("SMTP send error: {}", e))
}
// ============================================================
// GET /api/v1/settings/ip-whitelist
// ============================================================
async fn get_ip_whitelist(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let value: Option<String> = sqlx::query_scalar(
"SELECT value FROM system_config WHERE key = 'ip_whitelist'",
)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load ip_whitelist");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let entries: Vec<String> = serde_json::from_str(&value.unwrap_or_default()).unwrap_or_default();
Ok(Json(json!({ "entries": entries })))
}
// ============================================================
// PUT /api/v1/settings/ip-whitelist
// ============================================================
async fn update_ip_whitelist(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<IpWhitelistUpdate>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
// Validate each entry
for entry in &req.entries {
if entry.parse::<ipnet::IpNet>().is_err() && entry.parse::<std::net::IpAddr>().is_err() {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "bad_request", "message": format!("Invalid CIDR or IP: {}", entry) } }),
),
));
}
}
let json_str = serde_json::to_string(&req.entries).unwrap_or_else(|_| "[]".to_string());
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
// Update in-memory AuthConfig for immediate enforcement
state.auth_config.update_ip_whitelist(req.entries.clone());
log_event(
&state.db,
AuditAction::ConfigChanged,
Some(auth.user_id),
Some(&auth.username),
Some("ip_whitelist"),
Some("system_config"),
json!({ "entries": req.entries }),
None,
None,
)
.await;
Ok(Json(json!({ "entries": req.entries })))
}
// ============================================================
// POST /api/v1/settings/audit-integrity
// ============================================================
/// Verify audit log hash chain integrity.
/// Returns whether the chain is intact, rows checked, and any errors.
async fn audit_integrity(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
write_access_required(&auth)?;
let result = verify_integrity(&state.db).await;
log_event(
&state.db,
AuditAction::AuditIntegrityVerified,
Some(auth.user_id),
Some(&auth.username),
Some("audit_log"),
None,
json!({
"intact": result.intact,
"rows_checked": result.rows_checked,
"error_count": result.errors.len(),
}),
None,
None,
)
.await;
Ok(Json(json!({
"intact": result.intact,
"rows_checked": result.rows_checked,
"errors": result.errors.iter().map(|e| json!({
"row_id": e.row_id,
"expected_hash": e.expected_hash,
"actual_hash": e.actual_hash,
})).collect::<Vec<_>>(),
})))
}

838
crates/pm-web/src/routes/sso.rs Executable file
View File

@ -0,0 +1,838 @@
//! Generic OIDC SSO routes (Keycloak, Azure AD, Custom).
//!
//! Public routes (no auth required):
//! GET /api/v1/auth/sso/login — redirect to OIDC provider authorization URL
//! GET /api/v1/auth/sso/callback — handle OIDC provider callback, redirect to frontend SPA
//!
//! Backward-compatible aliases:
//! GET /api/v1/auth/azure/login → redirects to generic SSO login
//! GET /api/v1/auth/azure/callback → redirects to generic SSO callback
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Redirect},
routing::get,
Router,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use chrono::Utc;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use pm_auth::{jwt::issue_access_token, refresh};
use pm_core::audit::{log_event, AuditAction};
use serde::Deserialize;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::AppState;
// ============================================================
// Data structures
// ============================================================
#[derive(Clone)]
pub struct SsoSession {
pub code_verifier: String,
pub created_at: chrono::DateTime<Utc>,
}
#[derive(Debug, Deserialize)]
struct TokenResponse {
#[allow(dead_code)]
access_token: Option<String>,
id_token: Option<String>,
#[allow(dead_code)]
token_type: Option<String>,
#[allow(dead_code)]
expires_in: Option<i64>,
}
#[derive(Debug, Deserialize)]
struct IdTokenClaims {
email: Option<String>,
name: Option<String>,
sub: Option<String>,
oid: Option<String>,
preferred_username: Option<String>,
}
#[derive(Debug, sqlx::FromRow)]
struct DbUserForSso {
id: Uuid,
username: String,
display_name: String,
role: String,
is_active: bool,
mfa_enabled: bool,
}
/// OIDC provider configuration from database.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct OidcConfig {
pub enabled: bool,
pub provider_type: String,
pub display_name: String,
pub discovery_url: String,
pub client_id: String,
pub client_secret: String,
pub redirect_uri: String,
pub scopes: String,
}
/// Cached OIDC discovery document.
#[derive(Debug, Clone)]
pub struct OidcDiscovery {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub jwks_uri: String,
pub userinfo_endpoint: Option<String>,
pub fetched_at: chrono::DateTime<Utc>,
}
/// Cache for OIDC discovery documents and JWKS with TTL-based refresh.
#[derive(Default)]
pub struct OidcCache {
pub discovery: Option<OidcDiscovery>,
pub jwks: Option<serde_json::Value>,
pub jwks_fetched_at: Option<chrono::DateTime<Utc>>,
}
/// JWKS cache TTL in seconds (1 hour).
const JWKS_CACHE_TTL_SECS: i64 = 3600;
/// Discovery cache TTL in seconds (1 hour).
const DISCOVERY_CACHE_TTL_SECS: i64 = 3600;
// ============================================================
// Router
// ============================================================
pub fn public_router() -> Router<AppState> {
Router::new()
.route("/login", get(sso_login))
.route("/callback", get(sso_callback))
.route("/config", get(sso_config))
}
/// Backward-compatible Azure SSO routes — redirect to generic SSO endpoints.
pub fn azure_compat_router() -> Router<AppState> {
Router::new()
.route("/login", get(azure_login_redirect))
.route("/callback", get(azure_callback_redirect))
}
// ============================================================
// GET /api/v1/auth/sso/config
// ============================================================
/// Public endpoint returning minimal SSO configuration for the login page.
/// Returns only: enabled, display_name, auth_url — no secrets exposed.
async fn sso_config(
State(state): State<AppState>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let config = match load_oidc_config(&state.db).await {
Ok(c) => c,
Err(_) => {
// If we can't load config, SSO is effectively disabled
return Ok(Json(json!({
"enabled": false,
"display_name": "SSO",
"auth_url": ""
})));
},
};
Ok(Json(json!({
"enabled": config.enabled,
"display_name": if config.display_name.is_empty() { "SSO".to_string() } else { config.display_name },
"auth_url": "/api/v1/auth/sso/login"
})))
}
// ============================================================
// GET /api/v1/auth/sso/login
// ============================================================
async fn sso_login(
State(state): State<AppState>,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
let config = load_oidc_config(&state.db).await?;
if !config.enabled {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "SSO is not enabled" } })),
));
}
if config.discovery_url.is_empty() {
return Err((
StatusCode::FORBIDDEN,
Json(
json!({ "error": { "code": "forbidden", "message": "SSO discovery URL is not configured" } }),
),
));
}
// Fetch OIDC discovery document (with caching)
let discovery = match fetch_discovery(&state).await {
Ok(d) => d,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(
json!({ "error": { "code": "internal_error", "message": format!("Failed to fetch OIDC discovery: {}", e) } }),
),
));
},
};
// Generate PKCE code_verifier (32 random bytes → base64url)
let mut verifier_bytes = [0u8; 32];
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut verifier_bytes);
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
// code_challenge = BASE64URL(SHA256(code_verifier))
let challenge_digest = Sha256::digest(code_verifier.as_bytes());
let code_challenge = URL_SAFE_NO_PAD.encode(challenge_digest);
// Generate state token
let state_token = Uuid::new_v4().to_string();
// Store (state_token, code_verifier) in sso_sessions DashMap
state.sso_sessions.insert(
state_token.clone(),
SsoSession {
code_verifier,
created_at: Utc::now(),
},
);
// Build authorization URL from discovery
let encoded_scopes = urlencoding::encode(&config.scopes);
let auth_url = format!(
"{}?client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
discovery.authorization_endpoint,
urlencoding::encode(&config.client_id),
urlencoding::encode(&config.redirect_uri),
encoded_scopes,
code_challenge,
state_token
);
Ok(Redirect::to(&auth_url))
}
// ============================================================
// GET /api/v1/auth/sso/callback
// ============================================================
#[derive(Debug, Deserialize)]
struct CallbackParams {
code: Option<String>,
state: Option<String>,
error: Option<String>,
error_description: Option<String>,
}
async fn sso_callback(
State(state): State<AppState>,
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
) -> Result<Redirect, Redirect> {
let callback_url = &state.config.security.sso_callback_url;
let error_redirect = |code: &str, message: &str| -> Redirect {
let url = format!(
"{}?error={}&error_description={}",
callback_url,
urlencoding::encode(code),
urlencoding::encode(message)
);
Redirect::to(&url)
};
if let Some(error) = params.error {
let desc = params.error_description.unwrap_or_default();
let message = format!("OIDC provider error: {} - {}", error, desc);
return Err(error_redirect("sso_error", &message));
}
let code = match params.code {
Some(c) => c,
None => return Err(error_redirect("bad_request", "Missing authorization code")),
};
let state_token = match params.state {
Some(s) => s,
None => return Err(error_redirect("bad_request", "Missing state parameter")),
};
let sso_session = match state.sso_sessions.remove(&state_token).map(|(_, v)| v) {
Some(s) => s,
None => {
return Err(error_redirect(
"bad_request",
"Invalid or expired state token",
))
},
};
let config = match load_oidc_config(&state.db).await {
Ok(c) => c,
Err(_) => {
return Err(error_redirect(
"internal_error",
"Failed to load OIDC config",
))
},
};
let discovery = match fetch_discovery(&state).await {
Ok(d) => d,
Err(e) => {
tracing::error!(error = %e, "Failed to fetch OIDC discovery");
return Err(error_redirect(
"internal_error",
"Failed to fetch OIDC discovery",
));
},
};
// Exchange code for tokens
let client = match reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
{
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Failed to build HTTP client");
return Err(error_redirect("internal_error", "HTTP client error"));
},
};
let mut params_vec: Vec<(&str, String)> = vec![
("grant_type", "authorization_code".to_string()),
("code", code.clone()),
("redirect_uri", config.redirect_uri.clone()),
("client_id", config.client_id.clone()),
("code_verifier", sso_session.code_verifier.clone()),
];
// For confidential clients (Azure AD), include client_secret
if !config.client_secret.is_empty() {
params_vec.push(("client_secret", config.client_secret.clone()));
}
let token_resp = match client
.post(&discovery.token_endpoint)
.form(&params_vec)
.send()
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Token exchange request failed");
return Err(error_redirect(
"sso_error",
&format!("Token exchange failed: {}", e),
));
},
};
if !token_resp.status().is_success() {
let status = token_resp.status();
let body = token_resp.text().await.unwrap_or_default();
tracing::error!(status = %status, body = %body, "Token exchange failed");
return Err(error_redirect(
"sso_error",
&format!("Token exchange failed: HTTP {}", status),
));
}
let token_data: TokenResponse = match token_resp.json().await {
Ok(d) => d,
Err(e) => {
tracing::error!(error = %e, "Failed to parse token response");
return Err(error_redirect(
"internal_error",
"Failed to parse token response",
));
},
};
let id_token = match token_data.id_token {
Some(t) => t,
None => return Err(error_redirect("sso_error", "No id_token in response")),
};
let claims = match verify_id_token(&id_token, &config, &discovery, &state.oidc_cache).await {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Failed to verify id_token");
return Err(error_redirect(
"internal_error",
"Failed to verify id_token",
));
},
};
let email = claims.email.unwrap_or_default();
let name = claims.name.unwrap_or_default();
let oidc_sub = claims.sub.unwrap_or_default();
let azure_oid = claims.oid.unwrap_or_default();
let preferred_username = claims.preferred_username.unwrap_or_else(|| email.clone());
let provider_subject = if !oidc_sub.is_empty() {
oidc_sub.clone()
} else if !azure_oid.is_empty() {
azure_oid.clone()
} else {
return Err(error_redirect(
"sso_error",
"Missing subject identifier in id_token",
));
};
if email.is_empty() {
return Err(error_redirect("sso_error", "Missing email in id_token"));
}
let auth_provider = match config.provider_type.as_str() {
"keycloak" => "keycloak",
"azure" => "azure_sso",
_ => "oidc",
};
// First try exact match: email AND auth_provider
let user_opt: Option<DbUserForSso> = match sqlx::query_as(
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
FROM users WHERE email = $1 AND auth_provider = $2::auth_provider"#,
)
.bind(&email)
.bind(auth_provider)
.fetch_optional(&state.db)
.await
{
Ok(o) => o,
Err(e) => {
tracing::error!(error = %e, "Failed to look up SSO user");
return Err(error_redirect("internal_error", "Database error"));
},
};
let user = match user_opt {
Some(u) if !u.is_active => {
return Err(error_redirect("account_disabled", "Account is disabled"));
},
Some(u) => u,
None => {
// Try to find existing user by email alone (may have different auth_provider)
let existing_user: Option<DbUserForSso> = match sqlx::query_as(
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
FROM users WHERE email = $1"#,
)
.bind(&email)
.fetch_optional(&state.db)
.await
{
Ok(o) => o,
Err(e) => {
tracing::error!(error = %e, "Failed to look up existing user by email");
return Err(error_redirect("internal_error", "Database error"));
},
};
match existing_user {
Some(existing) if !existing.is_active => {
return Err(error_redirect("account_disabled", "Account is disabled"));
},
Some(existing) => {
// Link existing local user to SSO provider
tracing::info!(user_id = %existing.id, "Linking existing user to SSO provider");
if let Err(e) = sqlx::query(
"UPDATE users SET auth_provider = $1::auth_provider, azure_oid = COALESCE(azure_oid, $2), oidc_sub = COALESCE(oidc_sub, $3) WHERE id = $4",
)
.bind(auth_provider)
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
.bind(existing.id)
.execute(&state.db)
.await
{
tracing::error!(error = %e, "Failed to link user to SSO provider");
return Err(error_redirect("internal_error", "Failed to link SSO account"));
}
log_event(
&state.db,
AuditAction::UserCreated,
None,
Some(auth_provider),
Some("user"),
Some(&existing.id.to_string()),
json!({ "action": "sso_link", "auth_provider": auth_provider, "email": email }),
None,
None,
)
.await;
DbUserForSso {
id: existing.id,
username: existing.username.clone(),
display_name: if name.is_empty() {
existing.display_name.clone()
} else {
name
},
role: existing.role.clone(),
is_active: existing.is_active,
mfa_enabled: existing.mfa_enabled,
}
},
None => {
// No existing user - create new one
let id: Uuid = match sqlx::query_scalar(
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid, oidc_sub)
VALUES ($1, $2, $3, 'reporter'::user_role, $4::auth_provider, $5, $6)
RETURNING id"#,
)
.bind(&preferred_username)
.bind(&name)
.bind(&email)
.bind(auth_provider)
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
.fetch_one(&state.db)
.await
{
Ok(id) => id,
Err(e) => {
tracing::error!(error = %e, "Failed to create SSO user");
return Err(error_redirect("internal_error", "Failed to create user"));
},
};
log_event(
&state.db,
AuditAction::UserCreated,
None,
Some(auth_provider),
Some("user"),
Some(&id.to_string()),
json!({ "auth_provider": auth_provider, "email": email }),
None,
None,
)
.await;
DbUserForSso {
id,
username: preferred_username,
display_name: name,
role: "reporter".to_string(),
is_active: true,
mfa_enabled: false,
}
},
}
},
};
// Update last_login_at and provider subject IDs
if let Err(e) = sqlx::query(
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1), oidc_sub = COALESCE(oidc_sub, $2) WHERE id = $3",
)
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
.bind(user.id)
.execute(&state.db)
.await
{
tracing::error!(error = %e, "Failed to update last_login_at");
return Err(error_redirect("internal_error", "Database error"));
}
let access_ttl = state.config.security.jwt_access_ttl_secs as i64;
let access_token = match issue_access_token(
user.id,
&user.username,
&user.role,
access_ttl,
&state.signing_key_pem,
) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "Failed to issue access token");
return Err(error_redirect("internal_error", "Token issuance failed"));
},
};
let raw_refresh = match refresh::issue(&state.db, user.id, None, None).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "Failed to issue refresh token");
return Err(error_redirect(
"internal_error",
"Refresh token issuance failed",
));
},
};
log_event(
&state.db,
AuditAction::UserLogin,
Some(user.id),
Some(&user.username),
None,
None,
json!({ "auth_provider": auth_provider }),
None,
None,
)
.await;
let user_json = json!({
"id": user.id.to_string(),
"username": user.username,
"display_name": user.display_name,
"role": user.role,
"auth_provider": auth_provider,
"mfa_enabled": user.mfa_enabled,
});
let redirect_url = format!(
"{}?access_token={}&refresh_token={}&token_type=Bearer&expires_in={}&user={}",
callback_url,
urlencoding::encode(&access_token),
urlencoding::encode(&raw_refresh.0),
access_ttl,
urlencoding::encode(&user_json.to_string()),
);
Ok(Redirect::to(&redirect_url))
}
// ============================================================
// Backward-compatible Azure SSO redirect handlers
// ============================================================
async fn azure_login_redirect(
State(state): State<AppState>,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
sso_login(State(state)).await
}
async fn azure_callback_redirect(
State(state): State<AppState>,
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
) -> Result<Redirect, Redirect> {
sso_callback(State(state), axum::extract::Query(params)).await
}
// ============================================================
// Database helpers
// ============================================================
async fn load_oidc_config(pool: &sqlx::PgPool) -> Result<OidcConfig, (StatusCode, Json<Value>)> {
let row: Option<OidcConfig> = sqlx::query_as(
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
)
.fetch_optional(pool)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to load oidc_config");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
Ok(row.unwrap_or(OidcConfig {
enabled: false,
provider_type: "azure".to_string(),
display_name: "Azure AD".to_string(),
discovery_url: String::new(),
client_id: String::new(),
client_secret: String::new(),
redirect_uri: String::new(),
scopes: "openid profile email".to_string(),
}))
}
// ============================================================
// OIDC Discovery & JWKS
// ============================================================
async fn fetch_discovery(state: &AppState) -> Result<OidcDiscovery, String> {
let config = match load_oidc_config(&state.db).await {
Ok(c) => c,
Err(_) => {
return Err("Failed to load OIDC config".to_string());
},
};
let discovery_url = config.discovery_url;
// Check cache first
{
let cache = state.oidc_cache.lock().await;
if let Some(ref disc) = cache.discovery {
let elapsed = Utc::now().signed_duration_since(disc.fetched_at);
if elapsed.num_seconds() < DISCOVERY_CACHE_TTL_SECS {
return Ok(disc.clone());
}
}
}
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;
let resp = client
.get(&discovery_url)
.send()
.await
.map_err(|e| format!("Discovery fetch failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!(
"Discovery fetch failed: HTTP {}{}",
status, body
));
}
let doc: Value = resp
.json()
.await
.map_err(|e| format!("Failed to parse discovery document: {}", e))?;
let discovery = OidcDiscovery {
issuer: doc
.get("issuer")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
authorization_endpoint: doc
.get("authorization_endpoint")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
token_endpoint: doc
.get("token_endpoint")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
jwks_uri: doc
.get("jwks_uri")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
userinfo_endpoint: doc
.get("userinfo_endpoint")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
fetched_at: Utc::now(),
};
{
let mut cache = state.oidc_cache.lock().await;
cache.discovery = Some(discovery.clone());
}
Ok(discovery)
}
async fn verify_id_token(
token: &str,
config: &OidcConfig,
discovery: &OidcDiscovery,
oidc_cache: &Arc<Mutex<OidcCache>>,
) -> Result<IdTokenClaims, String> {
let header = decode_header(token).map_err(|e| format!("Failed to decode JWT header: {}", e))?;
let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
let jwks = {
let cache = oidc_cache.lock().await;
let needs_fetch = match (&cache.jwks, &cache.jwks_fetched_at) {
(None, _) => true,
(Some(_), None) => true,
(Some(_), Some(fetched)) => {
let elapsed = Utc::now().signed_duration_since(*fetched);
elapsed.num_seconds() > JWKS_CACHE_TTL_SECS
},
};
if needs_fetch {
drop(cache);
let jwks_value = fetch_jwks(&discovery.jwks_uri).await?;
let mut cache = oidc_cache.lock().await;
cache.jwks = Some(jwks_value);
cache.jwks_fetched_at = Some(Utc::now());
cache.jwks.clone().unwrap()
} else {
cache.jwks.clone().unwrap()
}
};
let keys_array = jwks
.get("keys")
.ok_or("JWKS response missing 'keys' array")?
.as_array()
.ok_or("JWKS 'keys' is not an array")?;
let jwk = keys_array
.iter()
.find(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()))
.ok_or_else(|| format!("No matching JWK found for kid: {}", kid))?;
let n = jwk
.get("n")
.and_then(|v| v.as_str())
.ok_or("JWK missing 'n' (modulus) field")?;
let e = jwk
.get("e")
.and_then(|v| v.as_str())
.ok_or("JWK missing 'e' (exponent) field")?;
let decoding_key = DecodingKey::from_rsa_components(n, e)
.map_err(|e| format!("Failed to construct RSA decoding key: {}", e))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.iss = Some(HashSet::from([discovery.issuer.clone()]));
validation.aud = Some(HashSet::from([config.client_id.clone()]));
validation.leeway = 60;
let token_data = decode::<IdTokenClaims>(token, &decoding_key, &validation)
.map_err(|e| format!("JWT signature verification failed: {}", e))?;
Ok(token_data.claims)
}
async fn fetch_jwks(jwks_uri: &str) -> Result<serde_json::Value, String> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|e| format!("Failed to build HTTP client for JWKS fetch: {}", e))?;
let resp = client
.get(jwks_uri)
.send()
.await
.map_err(|e| format!("JWKS fetch request failed: {}", e))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("JWKS fetch failed: HTTP {}{}", status, body));
}
resp.json::<serde_json::Value>()
.await
.map_err(|e| format!("Failed to parse JWKS response: {}", e))
}

View File

@ -0,0 +1,145 @@
//! Fleet status routes.
//!
//! GET /api/v1/status/fleet — aggregate health and patch summary across all hosts.
use axum::{extract::State, http::StatusCode, response::Json, routing::get, Router};
use serde::Serialize;
use serde_json::{json, Value};
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new().route("/fleet", get(fleet_status))
}
// ── Response type ─────────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
pub struct FleetStatus {
pub total_hosts: i64,
pub healthy: i64,
pub degraded: i64,
pub unreachable: i64,
pub pending: i64,
pub total_pending_patches: i64,
pub hosts_requiring_reboot: i64,
pub compliance_pct: f64,
}
// ── GET /api/v1/status/fleet ──────────────────────────────────────────────────
pub async fn fleet_status(
State(state): State<AppState>,
) -> Result<Json<FleetStatus>, (StatusCode, Json<Value>)> {
// ── 1. Host health aggregates ─────────────────────────────────────────
let health_row: (i64, i64, i64, i64, i64) = sqlx::query_as(
r#"
SELECT
COUNT(*) AS total_hosts,
COUNT(*) FILTER (WHERE health_status = 'healthy') AS healthy,
COUNT(*) FILTER (WHERE health_status = 'degraded') AS degraded,
COUNT(*) FILTER (WHERE health_status = 'unreachable') AS unreachable,
COUNT(*) FILTER (WHERE health_status = 'pending') AS pending
FROM hosts
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query host health aggregates");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let (total_hosts, healthy, degraded, unreachable, pending) = health_row;
// ── 2. Total pending patches across fleet (latest row per host) ───────
let total_pending_patches: i64 = sqlx::query_scalar(
r#"
SELECT COALESCE(SUM(patch_count), 0)
FROM (
SELECT DISTINCT ON (host_id) patch_count
FROM host_patch_data
ORDER BY host_id, polled_at DESC
) latest
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query total pending patches");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
// ── 3. Hosts requiring a reboot (latest patch row per host) ───────────
let hosts_requiring_reboot: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM (
SELECT DISTINCT ON (host_id) available_patches
FROM host_patch_data
ORDER BY host_id, polled_at DESC
) latest
WHERE available_patches @> '[{"requires_reboot": true}]'
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query reboot-required hosts");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
// ── 4. Compliance: hosts with zero pending patches / total hosts ───────
// Hosts that have been polled and have patch_count == 0 are considered
// compliant. Hosts with no patch data at all are excluded from the
// compliance calculation.
let compliant_hosts: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM (
SELECT DISTINCT ON (host_id) patch_count
FROM host_patch_data
ORDER BY host_id, polled_at DESC
) latest
WHERE patch_count = 0
"#,
)
.fetch_one(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "fleet_status: failed to query compliant hosts");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
let compliance_pct = if total_hosts == 0 {
100.0_f64
} else {
(compliant_hosts as f64 / total_hosts as f64) * 100.0
};
// Round to one decimal place.
let compliance_pct = (compliance_pct * 10.0).round() / 10.0;
Ok(Json(FleetStatus {
total_hosts,
healthy,
degraded,
unreachable,
pending,
total_pending_patches,
hosts_requiring_reboot,
compliance_pct,
}))
}

571
crates/pm-web/src/routes/users.rs Executable file
View File

@ -0,0 +1,571 @@
//! User management routes.
//!
//! GET /api/v1/users — list users (admin only)
//! POST /api/v1/users — create user (admin only)
//! GET /api/v1/users/:id — get user detail
//! PUT /api/v1/users/:id — update user
//! DELETE /api/v1/users/:id — delete user (admin only)
//! GET /api/v1/users/me — current user profile
//! POST /api/v1/users/:id/revoke — revoke all sessions (admin only)
use axum::{
extract::{Path, State},
http::StatusCode,
response::Json,
routing::{delete, get, post, put},
Router,
};
use pm_auth::validate_password_strength;
use pm_auth::{hash_password, rbac::AuthUser, session::force_logout, verify_password};
use pm_core::{
audit::{log_event, AuditAction},
models::{
AdminResetPasswordRequest, ChangePasswordRequest, CreateUserRequest, UpdateUserRequest,
User,
},
};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::AppState;
pub fn router() -> Router<AppState> {
Router::new()
.route("/", get(list_users).post(create_user))
.route("/me", get(get_current_user))
.route("/me/password", put(change_own_password))
.route("/{id}", get(get_user).put(update_user).delete(delete_user))
.route("/{id}/password", put(admin_reset_password))
.route("/{id}/mfa", delete(admin_disable_mfa))
.route("/{id}/revoke", post(revoke_user_sessions))
}
async fn list_users(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Vec<User>>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
sqlx::query_as::<_, User>(
r#"SELECT id, username, display_name, email, role, auth_provider,
mfa_enabled, is_active, force_password_reset, last_login_at,
created_at, updated_at
FROM users ORDER BY username"#,
)
.fetch_all(&state.db)
.await
.map(Json)
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})
}
async fn create_user(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<CreateUserRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
// Validate password strength
if let Err(msg) = validate_password_strength(&req.password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
let hash = hash_password(&req.password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
let role = match req.role.to_lowercase().as_str() {
"admin" => "admin",
"reporter" => "reporter",
_ => "operator",
};
let id: Uuid = sqlx::query_scalar(
r#"INSERT INTO users (username, display_name, email, role, auth_provider, password_hash)
VALUES ($1, $2, $3, $4::user_role, 'local', $5)
RETURNING id"#,
)
.bind(&req.username)
.bind(req.display_name.as_deref().unwrap_or(&req.username))
.bind(&req.email)
.bind(role)
.bind(&hash)
.fetch_one(&state.db)
.await
.map_err(|e| {
let msg = if e.to_string().contains("unique") {
"Username or email already exists".to_string()
} else {
"Database error".to_string()
};
(
StatusCode::CONFLICT,
Json(json!({ "error": { "code": "conflict", "message": msg } })),
)
})?;
log_event(
&state.db,
AuditAction::UserCreated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({ "username": req.username }),
None,
None,
)
.await;
Ok(Json(json!({ "id": id, "message": "User created" })))
}
async fn get_current_user(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
fetch_user(&state.db, auth.user_id).await
}
async fn get_user(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
// Users can see themselves; admin can see anyone
if !auth.role.is_admin() && auth.user_id != id {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
fetch_user(&state.db, id).await
}
async fn fetch_user(
pool: &sqlx::PgPool,
id: Uuid,
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
let user: Option<User> = sqlx::query_as(
r#"SELECT id, username, display_name, email, role, auth_provider,
mfa_enabled, is_active, force_password_reset, last_login_at,
created_at, updated_at
FROM users WHERE id = $1"#,
)
.bind(id)
.fetch_optional(pool)
.await
.map_err(|e| {
tracing::error!(error = %e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?;
user.map(Json).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
)
})
}
async fn update_user(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<UpdateUserRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() && auth.user_id != id {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
));
}
// Only admins can change role or active status
if (req.role.is_some() || req.is_active.is_some() || req.force_password_reset.is_some())
&& !auth.role.is_admin()
{
return Err((
StatusCode::FORBIDDEN,
Json(
json!({ "error": { "code": "forbidden", "message": "Admin role required to change role, status, or force_password_reset" } }),
),
));
}
let role_str = req
.role
.as_deref()
.map(|r| match r.to_lowercase().as_str() {
"admin" => "admin",
"reporter" => "reporter",
_ => "operator",
});
let rows = sqlx::query(
r#"UPDATE users SET
display_name = COALESCE($1, display_name),
email = COALESCE($2, email),
role = COALESCE($3::user_role, role),
is_active = COALESCE($4, is_active),
force_password_reset = COALESCE($5, force_password_reset),
updated_at = NOW()
WHERE id = $6"#,
)
.bind(req.display_name.as_deref())
.bind(req.email.as_deref())
.bind(role_str)
.bind(req.is_active)
.bind(req.force_password_reset)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({}),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User updated" })))
}
async fn delete_user(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
if auth.user_id == id {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "bad_request", "message": "Cannot delete your own account" } }),
),
));
}
let rows = sqlx::query("DELETE FROM users WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
log_event(
&state.db,
AuditAction::UserDeleted,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({}),
None,
None,
)
.await;
Ok(Json(json!({ "message": "User deleted" })))
}
async fn revoke_user_sessions(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
let count = force_logout(&state.db, id).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
Ok(Json(
json!({ "message": "Sessions revoked", "count": count }),
))
}
// ============================================================
// PUT /api/v1/users/me/password — change own password
// ============================================================
async fn change_own_password(
State(state): State<AppState>,
auth: AuthUser,
Json(req): Json<ChangePasswordRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// Fetch current password hash
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
.bind(auth.user_id)
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to fetch password hash");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
)
})?
.flatten();
let hash_str = hash.unwrap_or_default();
let valid = verify_password(&req.current_password, &hash_str).unwrap_or(false);
if !valid {
return Err((
StatusCode::BAD_REQUEST,
Json(
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
),
));
}
// Validate new password strength
if let Err(msg) = validate_password_strength(&req.new_password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
let new_hash = hash_password(&req.new_password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
sqlx::query(
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, updated_at = NOW() WHERE id = $2",
)
.bind(&new_hash)
.bind(auth.user_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to update password");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
)
})?;
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&auth.user_id.to_string()),
json!({ "action": "password_change" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Password changed successfully" })))
}
// ============================================================
// PUT /api/v1/users/:id/password — admin reset password
// ============================================================
async fn admin_reset_password(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<AdminResetPasswordRequest>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
// Verify target user exists
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)")
.bind(id)
.fetch_one(&state.db)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
if !exists {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
// Validate new password strength
if let Err(msg) = validate_password_strength(&req.new_password) {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
));
}
let new_hash = hash_password(&req.new_password).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
)
})?;
sqlx::query(
"UPDATE users SET password_hash = $1, force_password_reset = $2, updated_at = NOW() WHERE id = $3",
)
.bind(&new_hash)
.bind(req.force_password_reset)
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to reset password");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to reset password" } })),
)
})?;
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({ "action": "admin_password_reset", "force_password_reset": req.force_password_reset }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "Password reset successfully" })))
}
// ============================================================
// DELETE /api/v1/users/:id/mfa — admin disable MFA
// ============================================================
async fn admin_disable_mfa(
State(state): State<AppState>,
auth: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
));
}
let rows = sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE, updated_at = NOW() WHERE id = $1")
.bind(id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to disable MFA");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
)
})?
.rows_affected();
if rows == 0 {
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
));
}
log_event(
&state.db,
AuditAction::UserUpdated,
Some(auth.user_id),
Some(&auth.username),
Some("user"),
Some(&id.to_string()),
json!({ "action": "admin_mfa_disabled" }),
None,
None,
)
.await;
Ok(Json(json!({ "message": "MFA disabled successfully" })))
}

205
crates/pm-web/src/routes/ws.rs Executable file
View File

@ -0,0 +1,205 @@
//! WebSocket relay routes — M7
//!
//! POST /api/v1/ws/ticket — create a single-use WS auth ticket (JWT-protected)
//! GET /api/v1/ws/jobs — browser WebSocket endpoint (ticket-authenticated)
use axum::{
extract::ws::{Message, WebSocket},
extract::{Query, State, WebSocketUpgrade},
http::StatusCode,
response::{Json, Response},
routing::{get, post},
Router,
};
use chrono::{Duration, Utc};
use pm_auth::rbac::AuthUser;
use serde::Deserialize;
use serde_json::{json, Value};
use sqlx::postgres::PgListener;
use ulid::Ulid;
use uuid::Uuid;
use crate::AppState;
// ── WsTicket ──────────────────────────────────────────────────────────────────
/// Single-use WebSocket authentication ticket stored in-memory.
#[derive(Debug, Clone)]
pub struct WsTicket {
pub user_id: Uuid,
pub role: String,
pub expires_at: chrono::DateTime<Utc>,
}
// ── Router ────────────────────────────────────────────────────────────────────
/// Router for ticket-issuance endpoint (JWT-protected, merged into protected_api).
pub fn ticket_router() -> Router<AppState> {
Router::new().route("/ws/ticket", post(create_ticket_handler))
}
/// Router for the WebSocket endpoint (ticket-authenticated, NO JWT middleware).
pub fn ws_router() -> Router<AppState> {
Router::new().route("/api/v1/ws/jobs", get(ws_handler))
}
// ── Error helper ─────────────────────────────────────────────────────────────
#[inline]
fn err(
status: StatusCode,
code: &'static str,
message: impl Into<String>,
) -> (StatusCode, Json<Value>) {
(
status,
Json(json!({ "error": { "code": code, "message": message.into() } })),
)
}
// ── POST /api/v1/ws/ticket ────────────────────────────────────────────────────
/// Issue a single-use WebSocket authentication ticket (60 s expiry).
pub async fn create_ticket_handler(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
let ticket_id = Ulid::new().to_string();
let expires_at = Utc::now() + Duration::seconds(60);
let ticket = WsTicket {
user_id: auth.user_id,
role: auth.role.as_str().to_string(),
expires_at,
};
state.ws_tickets.insert(ticket_id.clone(), ticket);
tracing::info!(
user_id = %auth.user_id,
username = %auth.username,
ticket = %ticket_id,
"WS ticket issued"
);
Ok(Json(json!({ "ticket": ticket_id })))
}
// ── GET /api/v1/ws/jobs ───────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
pub struct WsQuery {
pub ticket: String,
}
/// Browser WebSocket upgrade endpoint — authenticates via single-use ticket.
pub async fn ws_handler(
State(state): State<AppState>,
Query(q): Query<WsQuery>,
ws: WebSocketUpgrade,
) -> Result<Response, (StatusCode, Json<Value>)> {
// Validate and consume the ticket atomically.
let ticket = {
let entry = state.ws_tickets.get(&q.ticket);
match entry {
None => {
return Err(err(
StatusCode::UNAUTHORIZED,
"invalid_ticket",
"WebSocket ticket not found or already used",
));
},
Some(t) => {
if t.expires_at < Utc::now() {
drop(t);
state.ws_tickets.remove(&q.ticket);
return Err(err(
StatusCode::UNAUTHORIZED,
"ticket_expired",
"WebSocket ticket has expired",
));
}
t.clone()
},
}
};
// Single-use: remove immediately after validation.
state.ws_tickets.remove(&q.ticket);
tracing::info!(
user_id = %ticket.user_id,
role = %ticket.role,
"Browser WebSocket connection upgraded"
);
let db = state.db.clone();
Ok(ws.on_upgrade(move |socket| handle_browser_ws(socket, db, ticket)))
}
// ── WebSocket handler ─────────────────────────────────────────────────────────
/// Drive the browser WebSocket: LISTEN on `job_update` and forward payloads.
async fn handle_browser_ws(mut socket: WebSocket, db: sqlx::PgPool, ticket: WsTicket) {
// Acquire a dedicated PG listener connection.
let mut listener = match PgListener::connect_with(&db).await {
Ok(l) => l,
Err(e) => {
tracing::error!(error = %e, user_id = %ticket.user_id, "Failed to create PgListener");
let _ = socket
.send(Message::Text(
json!({ "error": "internal_error" }).to_string().into(),
))
.await;
return;
},
};
if let Err(e) = listener.listen("job_update").await {
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener LISTEN failed");
return;
}
tracing::info!(user_id = %ticket.user_id, "Browser WS: LISTEN job_update started");
loop {
tokio::select! {
// Forward PG notifications to the browser.
notify_result = listener.recv() => {
match notify_result {
Ok(notification) => {
let payload = notification.payload().to_string();
tracing::debug!(user_id = %ticket.user_id, payload = %payload, "Forwarding job_update");
if socket.send(Message::Text(payload.into())).await.is_err() {
tracing::info!(user_id = %ticket.user_id, "Browser WS send failed — client disconnected");
break;
}
}
Err(e) => {
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener recv error");
break;
}
}
}
// Handle incoming frames from the browser (ping/close).
msg = socket.recv() => {
match msg {
Some(Ok(Message::Close(_))) | None => {
tracing::info!(user_id = %ticket.user_id, "Browser WS closed by client");
break;
}
Some(Ok(Message::Ping(data))) if socket.send(Message::Pong(data.clone())).await.is_err() => {
break;
}
Some(Err(e)) => {
tracing::debug!(error = %e, user_id = %ticket.user_id, "Browser WS recv error");
break;
}
_ => {}
}
}
}
}
tracing::info!(user_id = %ticket.user_id, "Browser WS handler exiting");
}

View File

@ -0,0 +1,31 @@
[package]
name = "pm-worker"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[[bin]]
name = "pm-worker"
path = "src/main.rs"
[dependencies]
pm-core = { path = "../pm-core" }
pm-agent-client = { path = "../pm-agent-client" }
tokio = { workspace = true, features = ["full"] }
sqlx = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
futures = { workspace = true }
rustls = { workspace = true }
tokio-rustls = { version = "0.26" }
rustls-pemfile = { version = "2" }
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
reqwest = { workspace = true }

View File

@ -0,0 +1,45 @@
//! Helper for loading mTLS certificate/key material from disk.
//!
//! Reads PEM files referenced in [`SecurityConfig`] and returns the raw bytes
//! needed by [`pm_agent_client::AgentClient`].
use pm_core::config::SecurityConfig;
/// Raw PEM bytes for mTLS client authentication and CA verification.
pub struct AgentCerts {
pub client_cert: Vec<u8>,
pub client_key: Vec<u8>,
pub ca_cert: Vec<u8>,
}
/// Load agent mTLS certificates from the paths specified in [`SecurityConfig`].
///
/// Returns an error if any file cannot be read. The caller should handle
/// the error gracefully (log and skip the poll cycle) rather than crashing.
pub fn load_agent_certs(security: &SecurityConfig) -> anyhow::Result<AgentCerts> {
let client_cert = std::fs::read(&security.agent_client_cert_path).map_err(|e| {
anyhow::anyhow!(
"Failed to read agent client cert '{}': {}",
security.agent_client_cert_path,
e
)
})?;
let client_key = std::fs::read(&security.agent_client_key_path).map_err(|e| {
anyhow::anyhow!(
"Failed to read agent client key '{}': {}",
security.agent_client_key_path,
e
)
})?;
let ca_cert = std::fs::read(&security.ca_cert_path).map_err(|e| {
anyhow::anyhow!("Failed to read CA cert '{}': {}", security.ca_cert_path, e)
})?;
Ok(AgentCerts {
client_cert,
client_key,
ca_cert,
})
}

View File

@ -0,0 +1,86 @@
//! Periodic audit log integrity verification.
//!
//! Runs every 24 hours, walks the audit_log rows ordered by id,
//! verifies each row_hash matches the recomputed hash, and logs the
//! result as an `AuditIntegrityVerified` event. If tampering is
//! detected, logs an error and creates an alert.
use std::sync::Arc;
use std::time::Duration;
use sqlx::PgPool;
use pm_core::audit::{log_event, verify_integrity, AuditAction};
use pm_core::config::AppConfig;
/// Run the audit integrity verifier every 24 hours.
pub async fn run_audit_verifier(pool: PgPool, _config: Arc<AppConfig>) {
tracing::info!("Audit integrity verifier started");
// Run immediately on startup
verify_once(&pool).await;
let mut interval = tokio::time::interval(Duration::from_secs(24 * 60 * 60));
loop {
interval.tick().await;
tracing::info!("Running scheduled audit integrity verification");
verify_once(&pool).await;
}
}
/// Run a single integrity verification pass.
async fn verify_once(pool: &PgPool) {
let result = verify_integrity(pool).await;
if result.intact {
tracing::info!(
rows_checked = result.rows_checked,
"Audit integrity verification passed"
);
} else {
tracing::error!(
rows_checked = result.rows_checked,
error_count = result.errors.len(),
"Audit integrity verification FAILED — tampering detected!"
);
for err in &result.errors {
tracing::error!(
row_id = err.row_id,
expected_hash = %err.expected_hash,
actual_hash = %err.actual_hash,
"Audit chain integrity error"
);
}
}
// Log the verification event
log_event(
pool,
AuditAction::AuditIntegrityVerified,
None,
None,
Some("audit_log"),
None,
serde_json::json!({
"intact": result.intact,
"rows_checked": result.rows_checked,
"error_count": result.errors.len(),
"errors": result.errors.iter().take(10).map(|e| serde_json::json!({
"row_id": e.row_id,
"expected_hash": e.expected_hash,
"actual_hash": e.actual_hash,
})).collect::<Vec<_>>(),
}),
None,
None,
)
.await;
// Update last verified timestamp
let _ = sqlx::query(
"UPDATE system_config SET value = NOW()::text, updated_at = NOW() WHERE key = 'audit_integrity_last_verified'",
)
.execute(pool)
.await;
}

331
crates/pm-worker/src/email.rs Executable file
View File

@ -0,0 +1,331 @@
//! Email notification module.
//!
//! Loads SMTP configuration from `system_config` and sends notification emails
//! for patch job events (completion, failure) and maintenance window reminders.
//! All emails are optional and disabled by default via `notification_email_enabled`.
use lettre::{
message::{header::ContentType, Mailbox},
transport::smtp::authentication::Credentials,
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
};
use sqlx::PgPool;
use pm_core::audit::{log_event, AuditAction};
/// SMTP configuration loaded from `system_config`.
struct SmtpSettings {
enabled: bool,
host: String,
port: u16,
username: String,
password: String,
from: String,
tls_mode: String,
}
/// Notification preferences loaded from `system_config`.
struct NotificationSettings {
email_enabled: bool,
email_from: String,
recipients: Vec<String>,
}
/// Load SMTP settings from the `system_config` table.
async fn load_smtp_settings(pool: &PgPool) -> SmtpSettings {
let rows: Vec<(String, String)> = sqlx::query_as(
"SELECT key, value FROM system_config WHERE key IN (
'smtp_enabled', 'smtp_host', 'smtp_port', 'smtp_username',
'smtp_password', 'smtp_from', 'smtp_tls_mode'
)",
)
.fetch_all(pool)
.await
.unwrap_or_default();
let get = |key: &str| -> String {
rows.iter()
.find(|(k, _)| k == key)
.map(|(_, v)| v.clone())
.unwrap_or_default()
};
SmtpSettings {
enabled: get("smtp_enabled") == "true",
host: get("smtp_host"),
port: get("smtp_port").parse().unwrap_or(587),
username: get("smtp_username"),
password: get("smtp_password"),
from: get("smtp_from"),
tls_mode: get("smtp_tls_mode"),
}
}
/// Load notification preferences from `system_config`.
async fn load_notification_settings(pool: &PgPool) -> NotificationSettings {
let rows: Vec<(String, String)> = sqlx::query_as(
"SELECT key, value FROM system_config WHERE key IN (
'notification_email_enabled', 'notification_email_from', 'notification_email_recipients'
)",
)
.fetch_all(pool)
.await
.unwrap_or_default();
let get = |key: &str| -> String {
rows.iter()
.find(|(k, _)| k == key)
.map(|(_, v)| v.clone())
.unwrap_or_default()
};
let recipients: Vec<String> =
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
NotificationSettings {
email_enabled: get("notification_email_enabled") == "true",
email_from: get("notification_email_from"),
recipients,
}
}
/// Build an async SMTP transport from settings.
fn build_transport(settings: &SmtpSettings) -> Result<AsyncSmtpTransport<Tokio1Executor>, String> {
match settings.tls_mode.as_str() {
"tls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(&settings.host)
.map_err(|e| format!("TLS relay error: {}", e))?;
builder = builder.port(settings.port);
if !settings.username.is_empty() {
builder = builder.credentials(Credentials::new(
settings.username.clone(),
settings.password.clone(),
));
}
Ok(builder.build())
},
"starttls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(&settings.host)
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
builder = builder.port(settings.port);
if !settings.username.is_empty() {
builder = builder.credentials(Credentials::new(
settings.username.clone(),
settings.password.clone(),
));
}
Ok(builder.build())
},
_ => {
// "none" — plaintext / no TLS
let mut builder =
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(&settings.host)
.port(settings.port);
if !settings.username.is_empty() {
builder = builder.credentials(Credentials::new(
settings.username.clone(),
settings.password.clone(),
));
}
Ok(builder.build())
},
}
}
/// Send an email notification. Returns true if the email was sent successfully.
async fn send_email(pool: &PgPool, subject: &str, body: &str) -> bool {
let smtp = match load_smtp_settings(pool).await {
s if !s.enabled => {
tracing::debug!("SMTP not enabled, skipping email notification");
return false;
},
s => s,
};
let notif = load_notification_settings(pool).await;
if !notif.email_enabled {
tracing::debug!("Email notifications disabled, skipping");
return false;
}
if notif.recipients.is_empty() {
tracing::debug!("No email recipients configured, skipping notification");
return false;
}
let from_addr = if notif.email_from.is_empty() {
smtp.from.clone()
} else {
notif.email_from
};
let from_mailbox: Mailbox = match from_addr.parse() {
Ok(m) => m,
Err(e) => {
tracing::error!(error = %e, "Invalid from address for email notification");
return false;
},
};
let mut builder = Message::builder()
.from(from_mailbox.clone())
.subject(subject)
.header(ContentType::TEXT_PLAIN);
// Add all recipients
for recipient in &notif.recipients {
let mailbox: Mailbox = match recipient.parse() {
Ok(m) => m,
Err(e) => {
tracing::error!(error = %e, recipient = %recipient, "Invalid recipient address");
continue;
},
};
builder = builder.to(mailbox);
}
let email = match builder.body(body.to_string()) {
Ok(e) => e,
Err(e) => {
tracing::error!(error = %e, "Failed to build email message");
return false;
},
};
let transport = match build_transport(&smtp) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "Failed to build SMTP transport");
return false;
},
};
match transport.send(email).await {
Ok(_) => {
tracing::info!(subject, "Email notification sent successfully");
true
},
Err(e) => {
tracing::error!(error = %e, subject, "Failed to send email notification");
false
},
}
}
/// Send a patch failure notification email for a specific host.
pub async fn send_patch_failure_email(
pool: &PgPool,
host_fqdn: &str,
job_id: &str,
error_message: &str,
) {
let subject = format!("[Patch Manager] Patch Failed on {}", host_fqdn);
let body = format!(
"Patch operation failed on host: {host_fqdn}\n\
Job ID: {job_id}\n\
Error: {error_message}\n\
\n\
Please review the job details in the Patch Manager dashboard."
);
let sent = send_email(pool, &subject, &body).await;
log_event(
pool,
AuditAction::EmailNotificationSent,
None,
None,
Some("patch_job"),
Some(job_id),
serde_json::json!({
"type": "patch_failure",
"host_fqdn": host_fqdn,
"sent": sent,
}),
None,
None,
)
.await;
}
/// Send a job completion notification email.
pub async fn send_job_completion_email(
pool: &PgPool,
job_id: &str,
host_count: i64,
succeeded_count: i64,
failed_count: i64,
) {
let subject = format!("[Patch Manager] Job {} Completed", job_id);
let body = format!(
"Patch job completed: {job_id}\n\
Total hosts: {host_count}\n\
Succeeded: {succeeded_count}\n\
Failed: {failed_count}\n\
\n\
Please review the job details in the Patch Manager dashboard."
);
let sent = send_email(pool, &subject, &body).await;
log_event(
pool,
AuditAction::EmailNotificationSent,
None,
None,
Some("patch_job"),
Some(job_id),
serde_json::json!({
"type": "job_completion",
"host_count": host_count,
"succeeded_count": succeeded_count,
"failed_count": failed_count,
"sent": sent,
}),
None,
None,
)
.await;
}
/// Send a maintenance window reminder email.
#[allow(dead_code)]
pub async fn send_maintenance_window_reminder_email(
pool: &PgPool,
host_fqdn: &str,
window_label: &str,
start_at: &str,
) {
let subject = format!(
"[Patch Manager] Upcoming Maintenance Window: {}",
window_label
);
let body = format!(
"Maintenance window reminder:\n\
Host: {host_fqdn}\n\
Window: {window_label}\n\
Starts at: {start_at}\n\
\n\
Patch operations will begin at the scheduled time."
);
let sent = send_email(pool, &subject, &body).await;
log_event(
pool,
AuditAction::MaintenanceWindowReminder,
None,
None,
Some("maintenance_window"),
None,
serde_json::json!({
"type": "maintenance_reminder",
"host_fqdn": host_fqdn,
"window_label": window_label,
"sent": sent,
}),
None,
None,
)
.await;
}

View File

@ -0,0 +1,480 @@
//! Periodic health check poller for configured service and HTTP checks.
//!
//! Polls every `health_check_poll_interval_secs`, querying each enabled health
//! check definition and storing results in `host_health_check_results`.
//! Results older than 4 days are pruned on each cycle.
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use pm_core::{config::AppConfig, crypto};
use sqlx::{FromRow, PgPool};
use tokio::{sync::Semaphore, time};
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
use pm_agent_client::{AgentClient, AgentClientError};
// ─────────────────────────────────────────────────────────────────────────────
// DB row types
// ─────────────────────────────────────────────────────────────────────────────
/// Row fetched for each enabled health check, joined with host connection info.
#[derive(FromRow)]
#[allow(dead_code)]
struct HealthCheckRow {
id: Uuid,
host_id: Uuid,
name: String,
check_type: String,
service_name: Option<String>,
url: Option<String>,
expected_body: Option<String>,
ignore_cert_errors: Option<bool>,
basic_auth_user: Option<String>,
basic_auth_pass_encrypted: Option<Vec<u8>>,
basic_auth_pass_nonce: Option<Vec<u8>>,
target_host_id: Option<Uuid>,
ip_address: String,
agent_port: i32,
}
// ─────────────────────────────────────────────────────────────────────────────
// Public entry point
// ─────────────────────────────────────────────────────────────────────────────
/// Run the health check poller loop indefinitely.
///
/// On each tick all enabled health checks are queried concurrently (up to
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
/// to `host_health_check_results` and stale rows are pruned.
pub async fn run_health_check_poller(pool: PgPool, config: Arc<AppConfig>) {
let interval_secs = config.worker.health_check_poll_interval_secs;
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
tracing::info!(interval_secs, "Health check poller started");
loop {
ticker.tick().await;
// Load certs on each cycle so cert rotation is picked up automatically.
let certs = match load_agent_certs(&config.security) {
Ok(c) => c,
Err(e) => {
tracing::error!(
error = %e,
"Health check poller: failed to load agent certs — skipping cycle"
);
continue;
},
};
let client_cert = Arc::new(certs.client_cert);
let client_key = Arc::new(certs.client_key);
let ca_cert = Arc::new(certs.ca_cert);
// Load the crypto key for decrypting HTTP check passwords.
let crypto_key = match crypto::load_or_create_key(Path::new(crypto::KEY_PATH)) {
Ok(k) => Arc::new(k),
Err(e) => {
tracing::error!(
error = %e,
"Health check poller: failed to load crypto key — skipping cycle"
);
continue;
},
};
// Fetch all enabled health checks with host connection info.
let checks: Vec<HealthCheckRow> = match sqlx::query_as(
r#"
SELECT
hc.id,
hc.host_id,
hc.name,
hc.check_type,
hc.service_name,
hc.url,
hc.expected_body,
hc.ignore_cert_errors,
hc.basic_auth_user,
hc.basic_auth_pass_encrypted,
hc.basic_auth_pass_nonce,
hc.target_host_id,
host(COALESCE(th.ip_address, h.ip_address))::text AS ip_address,
COALESCE(th.agent_port, h.agent_port) AS agent_port
FROM host_health_checks hc
JOIN hosts h ON h.id = hc.host_id
LEFT JOIN hosts th ON th.id = hc.target_host_id
WHERE hc.enabled = TRUE
ORDER BY hc.id
"#,
)
.fetch_all(&pool)
.await
{
Ok(rows) => rows,
Err(e) => {
tracing::error!(error = %e, "Health check poller: failed to fetch health checks");
continue;
},
};
if checks.is_empty() {
tracing::debug!("Health check poller: no enabled health checks, skipping cycle");
prune_old_results(&pool).await;
continue;
}
let total = checks.len();
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
let mut handles = Vec::with_capacity(total);
for check in checks {
let pool = pool.clone();
let sem = semaphore.clone();
let cert = client_cert.clone();
let key = client_key.clone();
let ca = ca_cert.clone();
let ckey = crypto_key.clone();
let handle = tokio::spawn(async move {
let _permit = sem.acquire().await.expect("semaphore closed");
run_check(pool, check, &cert, &key, &ca, &ckey).await
});
handles.push(handle);
}
// Collect results and tally counts.
let mut healthy_count = 0usize;
let mut unhealthy_count = 0usize;
let mut error_count = 0usize;
for handle in handles {
match handle.await {
Ok(true) => healthy_count += 1,
Ok(false) => unhealthy_count += 1,
Err(e) => {
tracing::error!(error = %e, "Health check poller task panicked");
error_count += 1;
},
}
}
tracing::info!(
total,
healthy_count,
unhealthy_count,
error_count,
"Health check poll cycle complete"
);
// Prune results older than 4 days.
prune_old_results(&pool).await;
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Check dispatch
// ─────────────────────────────────────────────────────────────────────────────
/// Run a single health check and persist the result. Returns `true` if healthy.
async fn run_check(
pool: PgPool,
check: HealthCheckRow,
client_cert: &[u8],
client_key: &[u8],
ca_cert: &[u8],
crypto_key: &[u8; 32],
) -> bool {
let start = Instant::now();
let (healthy, detail) = match check.check_type.as_str() {
"service" => run_service_check(&check, client_cert, client_key, ca_cert).await,
"http" => run_http_check(&check, crypto_key).await,
other => {
tracing::warn!(
check_id = %check.id,
check_type = other,
"Unknown health check type — treating as unhealthy"
);
(false, format!("Unknown check type: {other}"))
},
};
let latency_ms = start.elapsed().as_millis() as i32;
// Persist the result.
if let Err(e) = sqlx::query(
r#"
INSERT INTO host_health_check_results (check_id, healthy, detail, latency_ms)
VALUES ($1, $2, $3, $4)
"#,
)
.bind(check.id)
.bind(healthy)
.bind(&detail)
.bind(latency_ms)
.execute(&pool)
.await
{
tracing::error!(
check_id = %check.id,
error = %e,
"Health check poller: failed to insert result"
);
}
healthy
}
// ─────────────────────────────────────────────────────────────────────────────
// Service check (via mTLS AgentClient)
// ─────────────────────────────────────────────────────────────────────────────
/// Execute a service check by calling the agent's `/api/v1/system/services/{name}` endpoint.
async fn run_service_check(
check: &HealthCheckRow,
client_cert: &[u8],
client_key: &[u8],
ca_cert: &[u8],
) -> (bool, String) {
let service_name = match &check.service_name {
Some(name) => name.clone(),
None => {
return (false, "Service check missing service_name".to_string());
},
};
let client = match AgentClient::new(
&check.ip_address,
check.agent_port as u16,
client_cert,
client_key,
ca_cert,
) {
Ok(c) => c,
Err(e) => {
return (false, format!("Failed to build AgentClient: {e}"));
},
};
match client.service_status(&service_name).await {
Ok(data) => {
let detail = if data.healthy {
format!(
"Service '{}' is {}/{} (enabled: {})",
data.name, data.active_state, data.sub_state, data.enabled_state
)
} else {
format!(
"Service '{}' status: {}/{} (unhealthy, enabled: {})",
data.name, data.active_state, data.sub_state, data.enabled_state
)
};
(data.healthy, detail)
},
Err(AgentClientError::Timeout) => (
false,
format!("Agent timed out querying service '{service_name}'"),
),
Err(AgentClientError::Connect(_)) => (
false,
format!("Agent connection refused for service '{service_name}'"),
),
Err(AgentClientError::ApiError { code, message }) => {
// 404, 400, 500 etc. from the agent means the service is unhealthy.
(false, format!("Agent error [{code}]: {message}"))
},
Err(e) => (
false,
format!("Agent error querying service '{service_name}': {e}"),
),
}
}
// ─────────────────────────────────────────────────────────────────────────────
// HTTP check (via reqwest, no mTLS)
// ─────────────────────────────────────────────────────────────────────────────
/// Execute an HTTP check by making a GET request to the configured URL.
/// Supports optional basic auth (decrypted from DB) and substring body matching.
async fn run_http_check(check: &HealthCheckRow, crypto_key: &[u8; 32]) -> (bool, String) {
let url = match &check.url {
Some(u) => u.clone(),
None => {
return (false, "HTTP check missing URL".to_string());
},
};
// Build a reqwest client for this check.
// Use danger_accept_invalid_certs if ignore_cert_errors is set (default true).
let ignore_cert_errors = check.ignore_cert_errors.unwrap_or(true);
let client_builder = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.redirect(reqwest::redirect::Policy::limited(5));
let client = if ignore_cert_errors {
client_builder
.danger_accept_invalid_certs(true)
.build()
.unwrap_or_else(|_| reqwest::Client::new())
} else {
client_builder
.build()
.unwrap_or_else(|_| reqwest::Client::new())
};
// Build the request.
let mut request = client.get(&url);
// Add basic auth if configured.
if let Some(user) = &check.basic_auth_user {
// Decrypt the password if present.
let password = match (
&check.basic_auth_pass_encrypted,
&check.basic_auth_pass_nonce,
) {
(Some(enc), Some(nonce)) => match crypto::decrypt(enc, nonce, crypto_key) {
Ok(p) => p,
Err(e) => {
return (false, format!("Failed to decrypt basic auth password: {e}"));
},
},
_ => {
// No encrypted password stored — treat as missing credentials.
return (
false,
"HTTP check has basic_auth_user but no encrypted password".to_string(),
);
},
};
request = request.basic_auth(user.as_str(), Some(password.as_str()));
}
// Execute the request.
let response = match request.send().await {
Ok(r) => r,
Err(e) => {
if e.is_timeout() {
return (false, format!("HTTP check timed out: {url}"));
} else if e.is_connect() {
return (false, format!("HTTP check connection failed: {url}"));
} else {
return (false, format!("HTTP check request error: {e}"));
}
},
};
let status = response.status();
// Check HTTP status code.
if !status.is_success() {
return (
false,
format!("HTTP check returned status {} for {url}", status.as_u16()),
);
}
// Read the response body for substring matching.
let body = match response.text().await {
Ok(b) => b,
Err(e) => {
return (
false,
format!("HTTP check failed to read response body: {e}"),
);
},
};
// Check expected_body substring match.
if let Some(expected) = &check.expected_body {
if !body.contains(expected) {
return (
false,
format!("HTTP check body mismatch for {url}: expected substring not found"),
);
}
}
(
true,
format!("HTTP check OK for {url} (status {})", status.as_u16()),
)
}
// ─────────────────────────────────────────────────────────────────────────────
// Prune old results
// ─────────────────────────────────────────────────────────────────────────────
/// Delete health check results older than 4 days.
async fn prune_old_results(pool: &PgPool) {
match sqlx::query(
"DELETE FROM host_health_check_results WHERE checked_at < NOW() - INTERVAL '4 days'",
)
.execute(pool)
.await
{
Ok(result) => {
if result.rows_affected() > 0 {
tracing::info!(
rows_deleted = result.rows_affected(),
"Health check poller: pruned old results"
);
}
},
Err(e) => {
tracing::error!(error = %e, "Health check poller: failed to prune old results");
},
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Health check gate for job executor
// ─────────────────────────────────────────────────────────────────────────────
/// Check whether all enabled health checks for a host are healthy.
///
/// Returns `Ok(true)` if all checks pass (or no checks are configured),
/// `Ok(false)` if any check is unhealthy or has no result yet.
pub async fn check_host_health_checks(pool: &PgPool, host_id: Uuid) -> anyhow::Result<bool> {
// Check if there are any enabled health checks for this host.
let check_count: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM host_health_checks WHERE host_id = $1 AND enabled = TRUE",
)
.bind(host_id)
.fetch_one(pool)
.await?;
if check_count.0 == 0 {
// No health checks configured for this host — treat as healthy.
return Ok(true);
}
// Find any enabled check that has no healthy result or an unhealthy latest result.
let unhealthy_count: (i64,) = sqlx::query_as(
r#"
SELECT COUNT(*)
FROM host_health_checks hc
LEFT JOIN LATERAL (
SELECT healthy
FROM host_health_check_results r
WHERE r.check_id = hc.id
ORDER BY r.checked_at DESC
LIMIT 1
) latest ON true
WHERE hc.host_id = $1
AND hc.enabled = TRUE
AND (latest.healthy IS NULL OR latest.healthy = FALSE)
"#,
)
.bind(host_id)
.fetch_one(pool)
.await?;
Ok(unhealthy_count.0 == 0)
}

View File

@ -0,0 +1,249 @@
//! Periodic health poller for all registered hosts.
//!
//! Polls every host via the agent `/health` endpoint on each tick of
//! `health_poll_interval_secs`, with bounded concurrency controlled by a
//! [`tokio::sync::Semaphore`]. Also calls `/system/info` to refresh
//! `os_family`, `os_name`, `arch`, and `agent_version` in the hosts table.
use std::sync::Arc;
use pm_agent_client::{AgentClient, AgentClientError};
use pm_core::{config::AppConfig, models::HostHealthStatus};
use sqlx::{FromRow, PgPool};
use tokio::{sync::Semaphore, time};
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
/// Minimal host projection fetched for each poll cycle.
#[derive(Debug, FromRow)]
struct HostRow {
id: Uuid,
ip_address: String,
agent_port: i32,
}
/// Run the health poller loop indefinitely.
///
/// On each tick all registered hosts are queried concurrently (up to
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
/// to `host_health_data` and the `hosts` table is updated.
pub async fn run_health_poller(pool: PgPool, config: Arc<AppConfig>) {
let interval_secs = config.worker.health_poll_interval_secs;
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
tracing::info!(interval_secs, "Health poller started");
loop {
ticker.tick().await;
// Load certs on each cycle so cert rotation is picked up automatically.
let certs = match load_agent_certs(&config.security) {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Health poller: failed to load agent certs — skipping cycle");
continue;
},
};
let client_cert = Arc::new(certs.client_cert);
let client_key = Arc::new(certs.client_key);
let ca_cert = Arc::new(certs.ca_cert);
// Fetch all hosts.
let hosts: Vec<HostRow> = match sqlx::query_as(
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
)
.fetch_all(&pool)
.await
{
Ok(rows) => rows,
Err(e) => {
tracing::error!(error = %e, "Health poller: failed to fetch hosts");
continue;
},
};
if hosts.is_empty() {
tracing::debug!("Health poller: no hosts registered, skipping cycle");
continue;
}
let total = hosts.len();
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
let mut handles = Vec::with_capacity(total);
for host in hosts {
let pool = pool.clone();
let sem = semaphore.clone();
let cert = client_cert.clone();
let key = client_key.clone();
let ca = ca_cert.clone();
let handle = tokio::spawn(async move {
let _permit = sem.acquire().await.expect("semaphore closed");
poll_host_health(pool, host, &cert, &key, &ca).await
});
handles.push(handle);
}
// Collect results and tally counts.
let mut healthy = 0usize;
let mut degraded = 0usize;
let mut unreachable = 0usize;
for handle in handles {
match handle.await {
Ok(HostHealthStatus::Healthy) => healthy += 1,
Ok(HostHealthStatus::Degraded) => degraded += 1,
Ok(HostHealthStatus::Unreachable) => unreachable += 1,
Ok(_) => {},
Err(e) => tracing::error!(error = %e, "Health poller task panicked"),
}
}
tracing::info!(
total,
healthy,
degraded,
unreachable,
"Health poll cycle complete"
);
}
}
/// Poll a single host, persist the result, and return the determined status.
///
/// Also updates `agent_version` from the health response and
/// `os_family`/`os_name`/`arch` from the `/system/info` endpoint when available.
async fn poll_host_health(
pool: PgPool,
host: HostRow,
client_cert: &[u8],
client_key: &[u8],
ca_cert: &[u8],
) -> HostHealthStatus {
// Determine status, payload, agent version, and optional system info.
let (status, payload, agent_version, sys_info) = match AgentClient::new(
&host.ip_address,
host.agent_port as u16,
client_cert,
client_key,
ca_cert,
) {
Err(e) => {
tracing::warn!(
host_id = %host.id,
error = %e,
"Health poller: failed to build AgentClient"
);
(
HostHealthStatus::Unreachable,
serde_json::Value::Object(Default::default()),
None,
None,
)
},
Ok(client) => {
let (status, payload, version) = match client.health().await {
Ok(data) => {
let payload = serde_json::to_value(&data).unwrap_or_default();
(HostHealthStatus::Healthy, payload, Some(data.version))
},
Err(AgentClientError::Timeout) => {
tracing::warn!(host_id = %host.id, "Health poller: agent timed out");
(
HostHealthStatus::Unreachable,
serde_json::Value::Object(Default::default()),
None,
)
},
Err(AgentClientError::Connect(_)) => {
tracing::warn!(host_id = %host.id, "Health poller: agent connection refused");
(
HostHealthStatus::Unreachable,
serde_json::Value::Object(Default::default()),
None,
)
},
Err(e) => {
tracing::warn!(host_id = %host.id, error = %e, "Health poller: agent error");
(
HostHealthStatus::Degraded,
serde_json::Value::Object(Default::default()),
None,
)
},
};
// Try to fetch system info for OS/arch details (best-effort).
let sys_info = if status != HostHealthStatus::Unreachable {
match client.system_info().await {
Ok(info) => Some(info),
Err(e) => {
tracing::debug!(
host_id = %host.id,
error = %e,
"Health poller: failed to get system info (non-fatal)"
);
None
},
}
} else {
None
};
(status, payload, version, sys_info)
},
};
// Insert into host_health_data.
if let Err(e) = sqlx::query(
r#"
INSERT INTO host_health_data (host_id, status, payload)
VALUES ($1, $2, $3)
"#,
)
.bind(host.id)
.bind(&status)
.bind(&payload)
.execute(&pool)
.await
{
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to insert health data");
}
// Build OS name from system info components (e.g. "Ubuntu 24.04").
let os_name_from_sysinfo = sys_info
.as_ref()
.map(|i| format!("{} {}", i.os, i.os_version));
// Update hosts table with health status, agent version, and OS details.
// COALESCE preserves existing values when new data is unavailable.
if let Err(e) = sqlx::query(
r#"
UPDATE hosts
SET health_status = $2, last_health_at = NOW(),
agent_version = COALESCE($3, agent_version),
os_family = COALESCE($4, os_family),
os_name = COALESCE($5, os_name),
arch = COALESCE($6, arch)
WHERE id = $1
"#,
)
.bind(host.id)
.bind(&status)
.bind(&agent_version)
.bind(sys_info.as_ref().map(|i| i.os.as_str()))
.bind(os_name_from_sysinfo)
.bind(sys_info.as_ref().map(|i| i.architecture.as_str()))
.execute(&pool)
.await
{
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to update host status");
}
status
}

File diff suppressed because it is too large Load Diff

208
crates/pm-worker/src/main.rs Executable file
View File

@ -0,0 +1,208 @@
//! pm-worker — Linux Patch Manager background worker.
//!
//! Handles scheduled polling, job execution, maintenance window scheduling,
//! retry logic, email notifications, audit integrity verification, and data pruning.
mod agent_loader;
mod audit_verifier;
mod email;
mod health_check_poller;
mod health_poller;
mod job_executor;
mod maintenance_scheduler;
mod patch_poller;
mod refresh_listener;
mod ws_relay;
use chrono::Utc;
use pm_core::{config::AppConfig, db, logging};
use sqlx::PgPool;
use std::{sync::Arc, time::Duration};
use tokio::time;
use audit_verifier::run_audit_verifier;
use health_check_poller::run_health_check_poller;
use health_poller::run_health_poller;
use job_executor::run_job_executor;
use maintenance_scheduler::run_maintenance_scheduler;
use patch_poller::run_patch_poller;
use refresh_listener::run_refresh_listener;
use ws_relay::run_ws_relay;
/// Minimum number of applied migrations the worker requires before
/// accepting work. Prevents the worker from running against a schema
/// that hasn't been migrated yet.
const REQUIRED_MIGRATION_COUNT: i64 = 16;
/// How long to wait between schema-version checks before giving up.
const SCHEMA_CHECK_TIMEOUT: Duration = Duration::from_secs(120);
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Install the default crypto provider for rustls (required since 0.23)
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install rustls crypto provider");
// Load configuration
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
eprintln!("Config file not found or invalid, using defaults");
AppConfig::default()
});
// Initialize logging
logging::init(&config.logging);
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
"patch-manager-worker starting"
);
// Initialize database pool
let pool = db::init_pool(&config.database).await?;
// Wait for schema to be at the expected version (web process runs migrations)
wait_for_schema(&pool).await?;
let config = Arc::new(config);
// Spawn worker tasks
let heartbeat_handle = tokio::spawn(run_heartbeat(
pool.clone(),
config.worker.heartbeat_interval_secs,
));
// M4: agent health poller, patch data poller, on-demand refresh listener
let health_handle = tokio::spawn(run_health_poller(pool.clone(), config.clone()));
let patch_handle = tokio::spawn(run_patch_poller(pool.clone(), config.clone()));
let refresh_handle = tokio::spawn(run_refresh_listener(pool.clone(), config.clone()));
// M5: job execution engine
let job_exec_handle = tokio::spawn(run_job_executor(pool.clone(), config.clone()));
// M6: maintenance window scheduler
let maint_sched_handle = tokio::spawn(run_maintenance_scheduler(pool.clone(), config.clone()));
// M7: WS relay — streams agent job events → DB → pg_notify → browser WS
let ws_relay_handle = tokio::spawn(run_ws_relay(pool.clone(), config.clone()));
// M11: audit integrity verification (runs every 24 hours)
let audit_verifier_handle = tokio::spawn(run_audit_verifier(pool.clone(), config.clone()));
// Health check poller — runs configured service/HTTP health checks
let health_check_handle = tokio::spawn(run_health_check_poller(pool.clone(), config.clone()));
// Enrollment cleanup task (runs every hour)
let enrollment_cleanup_handle = tokio::spawn(run_enrollment_cleanup_task(pool.clone()));
tracing::info!("Worker tasks started");
// Wait for all tasks (they run indefinitely)
let _ = tokio::join!(
heartbeat_handle,
health_handle,
patch_handle,
refresh_handle,
job_exec_handle,
maint_sched_handle,
ws_relay_handle,
audit_verifier_handle,
health_check_handle,
enrollment_cleanup_handle,
);
Ok(())
}
/// Wait until the database schema has at least `REQUIRED_MIGRATION_COUNT`
/// successful migrations applied. Retries every 5 seconds up to
/// `SCHEMA_CHECK_TIMEOUT`.
async fn wait_for_schema(pool: &PgPool) -> anyhow::Result<()> {
let deadline = tokio::time::Instant::now() + SCHEMA_CHECK_TIMEOUT;
loop {
match db::check_schema_version(pool).await {
Ok(count) if count >= REQUIRED_MIGRATION_COUNT => {
tracing::info!(migration_count = count, "Schema version check passed");
return Ok(());
},
Ok(count) => {
tracing::warn!(
migration_count = count,
required = REQUIRED_MIGRATION_COUNT,
"Schema not ready, waiting..."
);
},
Err(e) => {
tracing::warn!(error = %e, "Schema version check failed, retrying...");
},
}
if tokio::time::Instant::now() >= deadline {
anyhow::bail!(
"Schema not ready after {}s — is the web process running migrations?",
SCHEMA_CHECK_TIMEOUT.as_secs()
);
}
time::sleep(Duration::from_secs(5)).await;
}
}
/// Writes a heartbeat row to `worker_heartbeat` every `interval_secs`.
/// The web process can query this to confirm the worker is alive.
async fn run_heartbeat(pool: PgPool, interval_secs: u64) {
let interval = Duration::from_secs(interval_secs);
let mut ticker = time::interval(interval);
loop {
ticker.tick().await;
let result = sqlx::query(
r#"
INSERT INTO worker_heartbeat (id, last_seen, worker_version)
VALUES (1, NOW(), $1)
ON CONFLICT (id) DO UPDATE
SET last_seen = EXCLUDED.last_seen,
worker_version = EXCLUDED.worker_version
"#,
)
.bind(env!("CARGO_PKG_VERSION"))
.execute(&pool)
.await;
match result {
Ok(_) => tracing::debug!("Worker heartbeat written"),
Err(e) => tracing::error!(error = %e, "Worker heartbeat failed"),
}
}
}
/// Periodically deletes expired enrollment requests.
async fn run_enrollment_cleanup_task(pool: PgPool) {
let mut interval = tokio::time::interval(Duration::from_secs(3600)); // Every hour
interval.tick().await; // Initial tick to run immediately if needed
loop {
interval.tick().await;
let now = Utc::now();
match sqlx::query("DELETE FROM enrollment_requests WHERE expires_at < $1")
.bind(now)
.execute(&pool)
.await
{
Ok(result) => {
if result.rows_affected() > 0 {
tracing::info!(
removed = result.rows_affected(),
"Purged expired enrollment requests"
);
}
},
Err(e) => tracing::error!(error = %e, "Failed to purge expired enrollment requests"),
}
}
}

View File

@ -0,0 +1,381 @@
//! Maintenance window scheduler.
//!
//! Polls every 60 seconds and performs two tasks:
//!
//! 1. **Auto-apply**: For each enabled maintenance window with `auto_apply = true`
//! that is currently open, if the host has pending patches and no existing
//! patch_apply job queued/running for that window, automatically creates one.
//!
//! 2. **Dispatch**: For each open window, dispatch any queued non-immediate
//! patch jobs associated with the window's host.
//!
//! A window is considered "open" when:
//! - `once` — `start_at <= NOW() < start_at + duration_minutes * '1 minute'`
//! - `daily` — current UTC time-of-day is within the window's daily slot
//! - `weekly` — same as daily, but only on the matching `recurrence_day` (0=Sun)
//! - `monthly` — same as daily, but only on the matching `recurrence_day` (1-31)
use std::sync::Arc;
use pm_core::config::AppConfig;
use sqlx::{FromRow, PgPool};
use tokio::time;
use uuid::Uuid;
use crate::job_executor::process_job;
// ─────────────────────────────────────────────────────────────────────────────
// Internal types
// ─────────────────────────────────────────────────────────────────────────────
#[derive(Debug, FromRow)]
struct OpenWindowHost {
host_id: Uuid,
}
#[derive(Debug, FromRow)]
struct QueuedJobId {
job_id: Uuid,
}
#[derive(Debug, FromRow)]
struct AutoApplyWindow {
window_id: Uuid,
host_id: Uuid,
}
#[derive(Debug, FromRow)]
#[allow(dead_code)]
struct PendingPatchHost {
host_id: Uuid,
patch_count: i32,
}
#[derive(Debug, FromRow)]
struct InsertedJobId {
job_id: Uuid,
}
// ─────────────────────────────────────────────────────────────────────────────
// Public entry point
// ─────────────────────────────────────────────────────────────────────────────
/// Run the maintenance scheduler indefinitely.
/// Spawned by `pm-worker/src/main.rs` alongside the job executor.
pub async fn run_maintenance_scheduler(pool: PgPool, config: Arc<AppConfig>) {
tracing::info!("Maintenance scheduler started");
// First tick fires immediately; consume it to align with job_executor.
let mut ticker = time::interval(std::time::Duration::from_secs(60));
ticker.tick().await;
loop {
ticker.tick().await;
tracing::debug!("Maintenance scheduler: checking open windows");
// Step 1: Auto-create patch_apply jobs for windows with auto_apply=true
auto_create_patch_jobs(pool.clone(), config.clone()).await;
// Step 2: Dispatch any queued non-immediate jobs for open windows
dispatch_open_window_jobs(pool.clone(), config.clone()).await;
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Step 1: Auto-create patch_apply jobs
// ─────────────────────────────────────────────────────────────────────────────
/// For each enabled maintenance window that is currently open AND has
/// `auto_apply = true`, check if the host has pending patches and no
/// existing patch_apply job for this window cycle. If so, create one.
async fn auto_create_patch_jobs(pool: PgPool, _config: Arc<AppConfig>) {
// Find all open windows with auto_apply=true
let auto_windows: Vec<AutoApplyWindow> = match sqlx::query_as(
r#"
SELECT mw.id AS window_id, mw.host_id
FROM maintenance_windows mw
WHERE mw.enabled = TRUE
AND mw.auto_apply = TRUE
AND (
( mw.recurrence = 'once'
AND mw.start_at <= NOW()
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
)
OR
( mw.recurrence = 'daily'
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
( mw.recurrence = 'weekly'
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
( mw.recurrence = 'monthly'
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
)
"#,
)
.fetch_all(&pool)
.await
{
Ok(w) => w,
Err(e) => {
tracing::error!(error = %e, "auto_create_patch_jobs: open-windows query failed");
return;
}
};
if auto_windows.is_empty() {
tracing::debug!("auto_create: no open auto-apply windows this cycle");
return;
}
tracing::info!(
auto_window_count = auto_windows.len(),
"auto_create: found open auto-apply windows"
);
for win in &auto_windows {
// Check if host has pending patches
let pending: Option<PendingPatchHost> = match sqlx::query_as(
r#"
SELECT host_id, patch_count
FROM host_patch_data
WHERE host_id = $1 AND patch_count > 0
"#,
)
.bind(win.host_id)
.fetch_optional(&pool)
.await
{
Ok(p) => p,
Err(e) => {
tracing::error!(
error = %e,
host_id = %win.host_id,
"auto_create: patch data query failed"
);
continue;
},
};
let Some(pending) = pending else {
tracing::debug!(
host_id = %win.host_id,
"auto_create: no pending patches, skipping"
);
continue;
};
// Check if there's already a queued/running patch_apply job for this host
// that was created during this window cycle (within the window's time range).
// We use a simpler check: any non-completed patch_apply job for this host
// that references this maintenance window, OR any non-immediate job without
// a window that was created since the window opened.
let existing_job: bool = match sqlx::query_scalar(
r#"
SELECT EXISTS(
SELECT 1 FROM patch_jobs pj
JOIN patch_job_hosts pjh ON pj.id = pjh.job_id
WHERE pjh.host_id = $1
AND pj.status IN ('queued', 'running', 'pending')
AND pj.kind = 'patch_apply'
AND (
pj.maintenance_window_id = $2
OR
(pj.immediate = FALSE AND pj.created_at >=
(SELECT start_at - INTERVAL '5 minutes' FROM maintenance_windows WHERE id = $2)
)
)
)
"#,
)
.bind(win.host_id)
.bind(win.window_id)
.fetch_one(&pool)
.await
{
Ok(b) => b,
Err(e) => {
tracing::error!(
error = %e,
host_id = %win.host_id,
"auto_create: existing job check failed"
);
continue;
}
};
if existing_job {
tracing::debug!(
host_id = %win.host_id,
window_id = %win.window_id,
"auto_create: existing job already queued/running, skipping"
);
continue;
}
// Create a new patch_apply job for this host, linked to the window.
let job: Option<InsertedJobId> = match sqlx::query_as(
r#"
WITH new_job AS (
INSERT INTO patch_jobs
(kind, status, maintenance_window_id, immediate, patch_selection, notes)
VALUES
('patch_apply', 'queued', $1, FALSE, '[]'::jsonb,
'Auto-created by maintenance window scheduler')
RETURNING id AS job_id
)
INSERT INTO patch_job_hosts (job_id, host_id, status)
SELECT new_job.job_id, $2, 'queued'
FROM new_job
RETURNING job_id
"#,
)
.bind(win.window_id)
.bind(win.host_id)
.fetch_optional(&pool)
.await
{
Ok(j) => j,
Err(e) => {
tracing::error!(
error = %e,
host_id = %win.host_id,
window_id = %win.window_id,
"auto_create: job insert failed"
);
continue;
},
};
if let Some(job) = job {
tracing::info!(
job_id = %job.job_id,
host_id = %win.host_id,
window_id = %win.window_id,
patch_count = pending.patch_count,
"auto_create: created patch_apply job for host in maintenance window"
);
}
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Step 2: Dispatch queued non-immediate jobs
// ─────────────────────────────────────────────────────────────────────────────
/// Find all hosts with a currently-open maintenance window, then for each,
/// find their queued non-immediate job entries and dispatch them.
async fn dispatch_open_window_jobs(pool: PgPool, config: Arc<AppConfig>) {
// ── 1. Find all host_ids with an open window right now ─────────────────
let open_hosts: Vec<OpenWindowHost> = match sqlx::query_as(
r#"
SELECT DISTINCT mw.host_id
FROM maintenance_windows mw
WHERE mw.enabled = TRUE
AND (
-- One-time: absolute window
( mw.recurrence = 'once'
AND mw.start_at <= NOW()
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
)
OR
-- Daily: time-of-day slot, any day
( mw.recurrence = 'daily'
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
-- Weekly: matching day-of-week + time-of-day slot
( mw.recurrence = 'weekly'
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
-- Monthly: matching day-of-month + time-of-day slot
( mw.recurrence = 'monthly'
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
)
"#,
)
.fetch_all(&pool)
.await
{
Ok(hosts) => hosts,
Err(e) => {
tracing::error!(error = %e, "dispatch_open_window_jobs: open-hosts query failed");
return;
}
};
if open_hosts.is_empty() {
tracing::debug!("Maintenance scheduler: no open windows this cycle");
return;
}
tracing::info!(
open_host_count = open_hosts.len(),
"Maintenance scheduler: found hosts with open windows"
);
// ── 2. For each open host, find distinct queued non-immediate job IDs ──
for host in open_hosts {
let job_ids: Vec<QueuedJobId> = match sqlx::query_as(
r#"
SELECT DISTINCT pjh.job_id
FROM patch_job_hosts pjh
JOIN patch_jobs j ON j.id = pjh.job_id
WHERE pjh.host_id = $1
AND pjh.status = 'queued'
AND j.immediate = FALSE
AND j.status != 'cancelled'
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
"#,
)
.bind(host.host_id)
.fetch_all(&pool)
.await
{
Ok(ids) => ids,
Err(e) => {
tracing::error!(
error = %e,
host_id = %host.host_id,
"dispatch_open_window_jobs: queued jobs query failed"
);
continue;
},
};
for job in job_ids {
tracing::info!(
job_id = %job.job_id,
host_id = %host.host_id,
"Maintenance scheduler: dispatching non-immediate job (window open)"
);
let (p, c) = (pool.clone(), config.clone());
let job_id = job.job_id;
tokio::spawn(async move {
process_job(p, c, job_id).await;
});
}
}
}

View File

@ -0,0 +1,202 @@
//! Periodic patch-data poller for all registered hosts.
//!
//! Polls every host via the agent `/patches` and `/packages` endpoints on
//! each tick of `patch_poll_interval_secs`, with bounded concurrency
//! controlled by a [`tokio::sync::Semaphore`].
use std::sync::Arc;
use pm_agent_client::AgentClient;
use pm_core::config::AppConfig;
use sqlx::{FromRow, PgPool};
use tokio::{sync::Semaphore, time};
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
/// Minimal host projection fetched for each poll cycle.
#[derive(Debug, FromRow)]
struct HostRow {
id: Uuid,
ip_address: String,
agent_port: i32,
}
/// Run the patch poller loop indefinitely.
///
/// On each tick all registered hosts are queried concurrently (up to
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
/// to `host_patch_data` and `hosts.last_patch_at` is updated.
pub async fn run_patch_poller(pool: PgPool, config: Arc<AppConfig>) {
let interval_secs = config.worker.patch_poll_interval_secs;
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
tracing::info!(interval_secs, "Patch poller started");
loop {
ticker.tick().await;
let certs = match load_agent_certs(&config.security) {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Patch poller: failed to load agent certs — skipping cycle");
continue;
},
};
let client_cert = Arc::new(certs.client_cert);
let client_key = Arc::new(certs.client_key);
let ca_cert = Arc::new(certs.ca_cert);
let hosts: Vec<HostRow> = match sqlx::query_as(
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
)
.fetch_all(&pool)
.await
{
Ok(rows) => rows,
Err(e) => {
tracing::error!(error = %e, "Patch poller: failed to fetch hosts");
continue;
},
};
if hosts.is_empty() {
tracing::debug!("Patch poller: no hosts registered, skipping cycle");
continue;
}
let total = hosts.len();
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
let mut handles = Vec::with_capacity(total);
for host in hosts {
let pool = pool.clone();
let sem = semaphore.clone();
let cert = client_cert.clone();
let key = client_key.clone();
let ca = ca_cert.clone();
let handle = tokio::spawn(async move {
let _permit = sem.acquire().await.expect("semaphore closed");
poll_host_patches(pool, host, &cert, &key, &ca).await
});
handles.push(handle);
}
let mut succeeded = 0usize;
let mut failed = 0usize;
for handle in handles {
match handle.await {
Ok(true) => succeeded += 1,
Ok(false) => failed += 1,
Err(e) => {
tracing::error!(error = %e, "Patch poller task panicked");
failed += 1;
},
}
}
tracing::info!(total, succeeded, failed, "Patch poll cycle complete");
}
}
/// Poll a single host for patch and package data, persist the result.
/// Returns `true` on success, `false` on any error.
async fn poll_host_patches(
pool: PgPool,
host: HostRow,
client_cert: &[u8],
client_key: &[u8],
ca_cert: &[u8],
) -> bool {
let client = match AgentClient::new(
&host.ip_address,
host.agent_port as u16,
client_cert,
client_key,
ca_cert,
) {
Ok(c) => c,
Err(e) => {
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: failed to build AgentClient");
return false;
},
};
// Fetch patches and packages concurrently.
let (patches_result, packages_result) =
tokio::join!(client.patches(), client.packages_upgradable());
let patches_data = match patches_result {
Ok(d) => d,
Err(e) => {
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: patches() failed");
return false;
},
};
let packages_data = match packages_result {
Ok(d) => d,
Err(e) => {
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: packages_upgradable() failed");
return false;
},
};
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
let installed_packages = serde_json::to_value(&packages_data.packages).unwrap_or_default();
let patch_count = patches_data.total as i32;
let cve_count = patches_data
.patches
.iter()
.filter(|p| !p.cve_ids.is_empty())
.count() as i32;
// Upsert into host_patch_data (one row per host, latest poll wins).
if let Err(e) = sqlx::query(
r#"
INSERT INTO host_patch_data
(host_id, available_patches, installed_packages, patch_count, cve_count)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (host_id) DO UPDATE SET
available_patches = EXCLUDED.available_patches,
installed_packages = EXCLUDED.installed_packages,
patch_count = EXCLUDED.patch_count,
cve_count = EXCLUDED.cve_count,
polled_at = NOW()
"#,
)
.bind(host.id)
.bind(&available_patches)
.bind(&installed_packages)
.bind(patch_count)
.bind(cve_count)
.execute(&pool)
.await
{
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to insert patch data");
return false;
}
// Update hosts.last_patch_at.
if let Err(e) = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
.bind(host.id)
.execute(&pool)
.await
{
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to update last_patch_at");
}
tracing::debug!(
host_id = %host.id,
patch_count,
cve_count,
"Patch data collected"
);
true
}

View File

@ -0,0 +1,269 @@
//! On-demand refresh listener.
//!
//! Listens on the PostgreSQL `refresh_requested` NOTIFY channel. When a
//! notification arrives the payload is expected to be a host UUID string.
//! The listener immediately polls that host for health and patch data and
//! persists the results — bypassing the normal poll intervals.
use std::sync::Arc;
use pm_agent_client::{AgentClient, AgentClientError};
use pm_core::{config::AppConfig, models::HostHealthStatus};
use sqlx::{FromRow, PgPool};
use tokio::time;
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
/// Minimal host row used for on-demand refresh.
#[derive(Debug, FromRow)]
struct HostRow {
id: Uuid,
ip_address: String,
agent_port: i32,
}
/// Run the LISTEN/NOTIFY refresh listener indefinitely.
///
/// Automatically reconnects if the underlying PostgreSQL connection drops.
pub async fn run_refresh_listener(pool: PgPool, config: Arc<AppConfig>) {
tracing::info!("Refresh listener started — listening on 'refresh_requested'");
loop {
if let Err(e) = listen_loop(&pool, &config).await {
tracing::error!(
error = %e,
"Refresh listener disconnected, reconnecting in 5s"
);
time::sleep(std::time::Duration::from_secs(5)).await;
}
}
}
/// Inner loop — returns `Err` only on a fatal listener error so the outer
/// loop can reconnect.
async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> {
let mut listener = sqlx::postgres::PgListener::connect(&config.database.url).await?;
listener.listen("refresh_requested").await?;
tracing::debug!("Refresh listener connected and listening");
loop {
let notification = listener.recv().await?;
let payload = notification.payload().to_string();
tracing::info!(payload, "Refresh notification received");
let host_id = match payload.parse::<Uuid>() {
Ok(id) => id,
Err(e) => {
tracing::warn!(
payload,
error = %e,
"Refresh listener: invalid UUID in notification payload"
);
continue;
},
};
// Fetch the host from the database.
let host: Option<HostRow> = sqlx::query_as(
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts WHERE id = $1",
)
.bind(host_id)
.fetch_optional(pool)
.await
.unwrap_or(None);
let host = match host {
Some(h) => h,
None => {
tracing::warn!(%host_id, "Refresh listener: host not found");
continue;
},
};
// Load certs for this refresh.
let certs = match load_agent_certs(&config.security) {
Ok(c) => c,
Err(e) => {
tracing::error!(
%host_id,
error = %e,
"Refresh listener: failed to load agent certs"
);
continue;
},
};
// Spawn the actual work so the listener loop is not blocked.
let pool_clone = pool.clone();
let cert = certs.client_cert;
let key = certs.client_key;
let ca = certs.ca_cert;
tokio::spawn(async move {
refresh_host(pool_clone, host, &cert, &key, &ca).await;
});
}
}
/// Perform a full health + patch refresh for one host and persist results.
async fn refresh_host(
pool: PgPool,
host: HostRow,
client_cert: &[u8],
client_key: &[u8],
ca_cert: &[u8],
) {
let client = match AgentClient::new(
&host.ip_address,
host.agent_port as u16,
client_cert,
client_key,
ca_cert,
) {
Ok(c) => c,
Err(e) => {
tracing::warn!(
host_id = %host.id,
error = %e,
"Refresh: failed to build AgentClient"
);
persist_health_unreachable(&pool, host.id).await;
return;
},
};
// ── Health ────────────────────────────────────────────────────────────
let (health_status, health_payload) = match client.health().await {
Ok(data) => {
let payload = serde_json::to_value(&data).unwrap_or_default();
(HostHealthStatus::Healthy, payload)
},
Err(AgentClientError::Timeout) | Err(AgentClientError::Connect(_)) => {
tracing::warn!(host_id = %host.id, "Refresh: agent unreachable");
(
HostHealthStatus::Unreachable,
serde_json::Value::Object(Default::default()),
)
},
Err(e) => {
tracing::warn!(host_id = %host.id, error = %e, "Refresh: health error");
(
HostHealthStatus::Degraded,
serde_json::Value::Object(Default::default()),
)
},
};
persist_health(&pool, host.id, &health_status, &health_payload).await;
// ── Patch data ────────────────────────────────────────────────────────
let (patches_result, packages_result) =
tokio::join!(client.patches(), client.packages_upgradable());
match (patches_result, packages_result) {
(Ok(patches_data), Ok(packages_data)) => {
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
let installed_packages =
serde_json::to_value(&packages_data.packages).unwrap_or_default();
let patch_count = patches_data.total as i32;
let cve_count = patches_data
.patches
.iter()
.filter(|p| !p.cve_ids.is_empty())
.count() as i32;
if let Err(e) = sqlx::query(
r#"
INSERT INTO host_patch_data
(host_id, available_patches, installed_packages, patch_count, cve_count)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (host_id) DO UPDATE SET
available_patches = EXCLUDED.available_patches,
installed_packages = EXCLUDED.installed_packages,
patch_count = EXCLUDED.patch_count,
cve_count = EXCLUDED.cve_count,
polled_at = NOW()
"#,
)
.bind(host.id)
.bind(&available_patches)
.bind(&installed_packages)
.bind(patch_count)
.bind(cve_count)
.execute(&pool)
.await
{
tracing::error!(
host_id = %host.id,
error = %e,
"Refresh: failed to insert patch data"
);
} else {
let _ = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
.bind(host.id)
.execute(&pool)
.await;
tracing::info!(
host_id = %host.id,
patch_count,
cve_count,
"On-demand refresh complete"
);
}
},
(Err(e), _) | (_, Err(e)) => {
tracing::warn!(
host_id = %host.id,
error = %e,
"Refresh: failed to collect patch data"
);
},
}
}
async fn persist_health_unreachable(pool: &PgPool, host_id: Uuid) {
let status = HostHealthStatus::Unreachable;
let payload = serde_json::Value::Object(Default::default());
persist_health(pool, host_id, &status, &payload).await;
}
async fn persist_health(
pool: &PgPool,
host_id: Uuid,
status: &HostHealthStatus,
payload: &serde_json::Value,
) {
if let Err(e) = sqlx::query(
r#"
INSERT INTO host_health_data (host_id, status, payload)
VALUES ($1, $2, $3)
"#,
)
.bind(host_id)
.bind(status)
.bind(payload)
.execute(pool)
.await
{
tracing::error!(
%host_id,
error = %e,
"Refresh: failed to insert health data"
);
}
if let Err(e) =
sqlx::query("UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1")
.bind(host_id)
.bind(status)
.execute(pool)
.await
{
tracing::error!(%host_id, error = %e, "Refresh: failed to update host health_status");
}
}

718
crates/pm-worker/src/ws_relay.rs Executable file
View File

@ -0,0 +1,718 @@
//! WS relay — M7
//!
//! For every running `patch_job_hosts` row that has an `agent_job_id`, open a
//! WebSocket to the corresponding agent, stream job-status events, update the
//! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS
//! handler can forward the event to connected clients.
use std::{collections::HashSet, error::Error, sync::Arc, time::Duration};
use anyhow::Context;
use futures::StreamExt;
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
ClientConfig as TlsClientConfig, RootCertStore,
};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::protocol::Message, Connector};
use uuid::Uuid;
use pm_agent_client::client::AgentClient;
use pm_agent_client::client::DEFAULT_AGENT_PORT;
use pm_core::config::AppConfig;
// ── Types ─────────────────────────────────────────────────────────────────────
#[derive(Debug, sqlx::FromRow)]
struct RunningHostJob {
job_id: Uuid,
host_id: Uuid,
agent_job_id: String,
host_address: String,
}
/// JSON event streamed by the agent over its WS endpoint.
#[derive(Debug, Deserialize)]
struct AgentWsEvent {
#[allow(dead_code)]
job_id: String,
status: String,
output: Option<String>,
error: Option<String>,
#[allow(dead_code)]
progress_percent: Option<u8>,
}
/// Payload broadcast via `pg_notify('job_update', …)`.
#[derive(Debug, Serialize)]
struct NotifyPayload {
event_type: String, // "host" or "job"
job_id: String,
host_id: String,
status: String,
output: Option<String>,
error_message: Option<String>,
agent_job_id: String,
// Job-level fields (only present when event_type === "job")
succeeded_count: Option<i64>,
failed_count: Option<i64>,
host_count: Option<i64>,
}
// ── Cert PEM bytes for building AgentClient ────────────────────────────────────
/// Raw PEM bytes read from the security config cert paths.
/// Used to build an [`AgentClient`] for HTTP polling fallback.
#[derive(Clone)]
struct CertPems {
client_cert: Vec<u8>,
client_key: Vec<u8>,
ca_cert: Vec<u8>,
}
/// Read the three PEM files referenced by the security config.
/// Mirrors the file reads in [`build_tls_config`] but returns raw bytes instead of
/// parsing them into rustls types.
async fn read_cert_pems(config: &AppConfig) -> anyhow::Result<CertPems> {
let sec = &config.security;
let client_cert = tokio::fs::read(&sec.agent_client_cert_path)
.await
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
let client_key = tokio::fs::read(&sec.agent_client_key_path)
.await
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
let ca_cert = tokio::fs::read(&sec.ca_cert_path)
.await
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
Ok(CertPems {
client_cert,
client_key,
ca_cert,
})
}
// ── Entry point ───────────────────────────────────────────────────────────────
/// Long-running task: polls the DB for running host-jobs and spawns a per-pair
/// relay task for each one that isn't already being tracked.
pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
tracing::info!("WS relay task started");
let active: Arc<Mutex<HashSet<(Uuid, Uuid)>>> = Arc::new(Mutex::new(HashSet::new()));
let mut interval = tokio::time::interval(Duration::from_secs(10));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
let rows = match query_running_jobs(&pool).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "ws_relay: DB poll failed");
continue;
},
};
for row in rows {
let key = (row.job_id, row.host_id);
// Skip pairs that already have an active relay.
if active.lock().await.contains(&key) {
continue;
}
// Build the rustls ClientConfig once per connection.
let tls_config = match build_tls_config(&config).await {
Ok(c) => Arc::new(c),
Err(e) => {
tracing::error!(error = %e, "ws_relay: TLS config error");
continue;
},
};
// Read raw cert PEM bytes for HTTP polling fallback.
let cert_pems = match read_cert_pems(&config).await {
Ok(p) => p,
Err(e) => {
tracing::error!(error = %e, "ws_relay: cert PEM read error");
continue;
},
};
let poll_interval = config.worker.ws_relay_poll_interval_secs;
active.lock().await.insert(key);
let pool_c = pool.clone();
let active_c = active.clone();
tokio::spawn(async move {
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
agent_job_id = %row.agent_job_id,
host = %row.host_address,
"WS relay: starting relay"
);
match relay_one_job(&pool_c, &row, tls_config, &cert_pems, poll_interval).await {
Ok(()) => tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: completed"
),
Err(e) => tracing::error!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: ended with error"
),
}
active_c.lock().await.remove(&key);
});
}
}
}
// ── DB helpers ────────────────────────────────────────────────────────────────
async fn query_running_jobs(pool: &PgPool) -> anyhow::Result<Vec<RunningHostJob>> {
sqlx::query_as::<_, RunningHostJob>(
r#"
SELECT
pjh.job_id,
pjh.host_id,
pjh.agent_job_id,
COALESCE(h.fqdn, host(h.ip_address)::text) AS host_address
FROM patch_job_hosts pjh
JOIN hosts h ON h.id = pjh.host_id
WHERE pjh.status = 'running'::job_status
AND pjh.agent_job_id IS NOT NULL
"#,
)
.fetch_all(pool)
.await
.context("query_running_jobs")
}
// ── TLS ───────────────────────────────────────────────────────────────────────
async fn build_tls_config(config: &AppConfig) -> anyhow::Result<TlsClientConfig> {
let sec = &config.security;
let cert_pem = tokio::fs::read(&sec.agent_client_cert_path)
.await
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
let key_pem = tokio::fs::read(&sec.agent_client_key_path)
.await
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
let ca_pem = tokio::fs::read(&sec.ca_cert_path)
.await
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
// Parse client certificate chain.
let client_certs: Vec<CertificateDer<'static>> = {
let mut cur = std::io::Cursor::new(&cert_pem);
rustls_pemfile::certs(&mut cur)
.collect::<Result<Vec<_>, _>>()
.context("parse client cert PEM")?
};
// Parse client private key.
let client_key: PrivateKeyDer<'static> = {
let mut cur = std::io::Cursor::new(&key_pem);
rustls_pemfile::private_key(&mut cur)
.context("parse client key PEM")?
.context("no private key in PEM")?
};
// Build root store from CA cert.
let mut root_store = RootCertStore::empty();
{
let mut cur = std::io::Cursor::new(&ca_pem);
for cert_result in rustls_pemfile::certs(&mut cur) {
root_store
.add(cert_result.context("read CA cert entry")?)
.context("add CA cert to root store")?;
}
}
let mut config = TlsClientConfig::builder()
.with_root_certificates(root_store)
.with_client_auth_cert(client_certs, client_key)
.context("build TlsClientConfig")?;
// WebSocket requires HTTP/1.1 — without ALPN the server may negotiate
// h2 (HTTP/2), which breaks the WebSocket upgrade handshake
// ("Key mismatch in Sec-WebSocket-Accept header").
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(config)
}
// ── Per-job relay ─────────────────────────────────────────────────────────────
async fn relay_one_job(
pool: &PgPool,
row: &RunningHostJob,
tls_config: Arc<TlsClientConfig>,
cert_pems: &CertPems,
poll_interval_secs: u64,
) -> anyhow::Result<()> {
let url = format!(
"wss://{}:{}/api/v1/ws/jobs",
row.host_address, DEFAULT_AGENT_PORT,
);
let (ws_stream, _) = match connect_async_tls_with_config(
url.as_str(),
None,
false,
Some(Connector::Rustls(tls_config)),
)
.await
{
Ok(ws) => ws,
Err(e) => {
// Log the full error chain for TLS debugging
let mut source = e.source();
let mut depth = 0;
while let Some(err) = source {
tracing::warn!(
job_id = %row.job_id,
host_id = %row.host_id,
depth,
error = %err,
"WS relay: TLS connection error detail"
);
source = err.source();
depth += 1;
}
// Fall back to HTTP polling instead of returning an error.
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
host = %row.host_address,
"WS relay: WebSocket connection failed, falling back to HTTP polling"
);
return relay_one_job_poll(pool, row, cert_pems, poll_interval_secs).await;
},
};
let (_sink, mut stream) = ws_stream.split();
while let Some(frame) = stream.next().await {
let frame = match frame {
Ok(f) => f,
Err(e) => {
tracing::warn!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: stream error"
);
break;
},
};
let text = match frame {
Message::Text(t) => t.to_string(),
Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(),
Message::Close(_) => {
tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly");
break;
},
_ => continue,
};
if text.is_empty() {
continue;
}
let event: AgentWsEvent = match serde_json::from_str(&text) {
Ok(e) => e,
Err(e) => {
tracing::warn!(
error = %e, raw = %text,
"WS relay: unparseable agent frame"
);
continue;
},
};
process_event(pool, row, &event).await;
if matches!(event.status.as_str(), "succeeded" | "failed" | "cancelled") {
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
status = %event.status,
"WS relay: terminal state — stopping"
);
break;
}
}
Ok(())
}
// ── Per-job HTTP polling fallback ─────────────────────────────────────────────
/// Fall back to HTTP polling when the WebSocket connection fails.
///
/// Builds an [`AgentClient`] from the same cert paths used for TLS, polls
/// `GET /api/v1/jobs/{id}` every `poll_interval_secs`, and calls
/// [`process_event`] whenever the status changes from the previous poll.
/// Stops when the job reaches a terminal state (succeeded/failed/cancelled).
async fn relay_one_job_poll(
pool: &PgPool,
row: &RunningHostJob,
cert_pems: &CertPems,
poll_interval_secs: u64,
) -> anyhow::Result<()> {
// Build the HTTP client using the same mTLS certs.
let agent_client = AgentClient::new(
&row.host_address,
DEFAULT_AGENT_PORT,
&cert_pems.client_cert,
&cert_pems.client_key,
&cert_pems.ca_cert,
)
.context("build AgentClient for polling fallback")?;
let mut last_status: Option<String> = None;
let poll_interval = Duration::from_secs(poll_interval_secs);
loop {
tokio::time::sleep(poll_interval).await;
let job_status = match agent_client.job_status(&row.agent_job_id).await {
Ok(s) => s,
Err(e) => {
tracing::debug!(
job_id = %row.job_id,
host_id = %row.host_id,
error = %e,
"WS relay poll: job_status request failed, will retry"
);
continue;
},
};
// Map agent status to the WS event format.
// The agent uses "completed" but pm-worker expects "succeeded".
// The agent uses "queued" but pm-worker expects "running".
let mapped_status = match job_status.status.as_str() {
"queued" => "running",
"running" => "running",
"succeeded" => "succeeded",
"completed" => "succeeded",
"failed" => "failed",
"cancelled" => "cancelled",
other => {
tracing::warn!(
status = %other,
job_id = %row.job_id,
"WS relay poll: unknown agent status, treating as running"
);
"running"
},
};
tracing::debug!(
job_id = %row.job_id,
host_id = %row.host_id,
agent_status = %job_status.status,
mapped_status = %mapped_status,
progress = ?job_status.progress_percent,
"WS relay poll: fetched job status"
);
// Only process when the status has changed since the last poll.
if last_status.as_deref() == Some(mapped_status) {
continue;
}
last_status = Some(mapped_status.to_string());
let event = AgentWsEvent {
job_id: job_status.job_id.clone(),
status: mapped_status.to_string(),
output: job_status.output.clone(),
error: job_status.error.clone(),
progress_percent: job_status.progress_percent,
};
process_event(pool, row, &event).await;
// Stop polling on terminal states.
if matches!(mapped_status, "succeeded" | "failed" | "cancelled") {
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
status = %mapped_status,
"WS relay poll: terminal state — stopping"
);
break;
}
}
Ok(())
}
// ── Event processing ──────────────────────────────────────────────────────────
async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) {
// Map agent status string to DB job_status enum value.
let db_status = match event.status.as_str() {
"running" => "running",
"succeeded" => "succeeded",
"failed" => "failed",
"cancelled" => "cancelled",
other => {
tracing::warn!(status = %other, "WS relay: unknown agent status");
return;
},
};
let output = event.output.as_deref().unwrap_or("");
let error_msg = event.error.as_deref();
// Determine timestamps based on terminal state.
let is_terminal = matches!(db_status, "succeeded" | "failed" | "cancelled");
// Update the DB row.
let update_result = if is_terminal {
sqlx::query(
r#"
UPDATE patch_job_hosts
SET status = $1::job_status,
output = CASE WHEN $2 != '' THEN $2 ELSE output END,
error_message = $3,
completed_at = NOW()
WHERE job_id = $4
AND host_id = $5
"#,
)
.bind(db_status)
.bind(output)
.bind(error_msg)
.bind(row.job_id)
.bind(row.host_id)
.execute(pool)
.await
} else {
sqlx::query(
r#"
UPDATE patch_job_hosts
SET status = $1::job_status,
output = CASE WHEN $2 != '' THEN $2 ELSE output END
WHERE job_id = $3
AND host_id = $4
"#,
)
.bind(db_status)
.bind(output)
.bind(row.job_id)
.bind(row.host_id)
.execute(pool)
.await
};
if let Err(e) = update_result {
tracing::error!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: DB update failed"
);
return;
}
// Also update the parent patch_jobs status when the host-level job reaches
// a terminal state: running → if all hosts terminal then update parent.
if is_terminal {
update_parent_job_status(pool, row.job_id).await;
}
// Fire pg_notify so browser WS handlers forward the host-level event.
let payload = NotifyPayload {
event_type: "host".to_string(),
job_id: row.job_id.to_string(),
host_id: row.host_id.to_string(),
status: db_status.to_string(),
output: event.output.clone(),
error_message: event.error.clone(),
agent_job_id: row.agent_job_id.clone(),
succeeded_count: None,
failed_count: None,
host_count: None,
};
let payload_json = match serde_json::to_string(&payload) {
Ok(s) => s,
Err(e) => {
tracing::error!(error = %e, "WS relay: failed to serialize notify payload");
return;
},
};
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
.bind(&payload_json)
.execute(pool)
.await
{
tracing::error!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: pg_notify failed"
);
} else {
tracing::debug!(
job_id = %row.job_id,
host_id = %row.host_id,
status = %db_status,
"WS relay: pg_notify sent"
);
}
}
// ── Parent job status rollup ──────────────────────────────────────────────────
/// After a host-level job reaches a terminal state, check whether ALL hosts for
/// that job are now terminal and update the parent `patch_jobs` row accordingly.
///
/// If the parent job transitions to a terminal status, also fires a `job_update`
/// pg_notify with `event_type: "job"` so the frontend can update the job row.
async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) {
// Count hosts that are still in a non-terminal state.
let pending: i64 = match sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM patch_job_hosts
WHERE job_id = $1
AND status NOT IN (
'succeeded'::job_status,
'failed'::job_status,
'cancelled'::job_status
)
"#,
)
.bind(job_id)
.fetch_one(pool)
.await
{
Ok(n) => n,
Err(e) => {
tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed");
return;
},
};
if pending > 0 {
return; // still hosts running — parent stays running
}
// All hosts terminal — determine final parent status and counts.
#[derive(sqlx::FromRow)]
struct RollupCounts {
total: i64,
succeeded: i64,
failed: i64,
}
let counts: RollupCounts = match sqlx::query_as(
r#"
SELECT
COUNT(*) AS total,
COUNT(*) FILTER (WHERE status = 'succeeded') AS succeeded,
COUNT(*) FILTER (WHERE status = 'failed') AS failed
FROM patch_job_hosts
WHERE job_id = $1
"#,
)
.bind(job_id)
.fetch_one(pool)
.await
{
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, %job_id, "update_parent_job_status: rollup query failed");
return;
},
};
let final_status = if counts.failed > 0 {
"failed"
} else {
"succeeded"
};
if let Err(e) = sqlx::query(
"UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2",
)
.bind(final_status)
.bind(job_id)
.execute(pool)
.await
{
tracing::error!(
error = %e,
%job_id,
status = %final_status,
"update_parent_job_status: UPDATE failed"
);
return;
}
tracing::info!(
%job_id,
status = %final_status,
"Parent job status updated"
);
// Fire job-level pg_notify so the frontend can update the job row.
let payload = NotifyPayload {
event_type: "job".to_string(),
job_id: job_id.to_string(),
host_id: String::new(), // no specific host for job-level events
status: final_status.to_string(),
output: None,
error_message: None,
agent_job_id: String::new(),
succeeded_count: Some(counts.succeeded),
failed_count: Some(counts.failed),
host_count: Some(counts.total),
};
let payload_json = match serde_json::to_string(&payload) {
Ok(s) => s,
Err(e) => {
tracing::error!(error = %e, %job_id, "update_parent_job_status: failed to serialize job-level notify payload");
return;
},
};
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
.bind(&payload_json)
.execute(pool)
.await
{
tracing::error!(
error = %e,
%job_id,
"update_parent_job_status: job-level pg_notify failed"
);
} else {
tracing::info!(
%job_id,
status = %final_status,
"Job-level pg_notify sent"
);
}
}