feat: add bump-version.sh script for version management
Automates version bumps across all version source files: - Cargo.toml (PRIMARY - workspace.package.version) - debian/changelog (prepend new entry) - debian/control (update Version field) - scripts/build-package.sh (update VERSION variable) - frontend/package.json (update version field) - Stale references check after bump Usage: ./scripts/bump-version.sh <new_version> <old_version>
This commit is contained in:
19
crates/pm-agent-client/Cargo.toml
Normal file
19
crates/pm-agent-client/Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "pm-agent-client"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
274
crates/pm-agent-client/src/client.rs
Executable file
274
crates/pm-agent-client/src/client.rs
Executable file
@ -0,0 +1,274 @@
|
||||
//! mTLS HTTP client for communicating with Linux Patch API agents.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use pm_agent_client::client::AgentClient;
|
||||
//!
|
||||
//! # async fn example() -> Result<(), pm_agent_client::error::AgentClientError> {
|
||||
//! let client = AgentClient::new(
|
||||
//! "192.168.1.10",
|
||||
//! 12443,
|
||||
//! include_bytes!("../certs/client.crt"),
|
||||
//! include_bytes!("../certs/client.key"),
|
||||
//! include_bytes!("../certs/ca.crt"),
|
||||
//! )?;
|
||||
//!
|
||||
//! let health = client.health().await?;
|
||||
//! println!("Agent status: {}", health.status);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use reqwest::{tls::Version, Certificate, ClientBuilder, Identity};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
use crate::{
|
||||
error::AgentClientError,
|
||||
types::{
|
||||
AgentEnvelope, AgentJobStatus, ApplyPatchesRequest, ApplyPatchesResponse, HealthData,
|
||||
PackagesData, PatchesData, RollbackResponse, ServiceStatusData, SystemInfoData,
|
||||
},
|
||||
};
|
||||
|
||||
/// Default TCP port that the Linux Patch API agent listens on.
|
||||
pub const DEFAULT_AGENT_PORT: u16 = 12443;
|
||||
|
||||
/// Request timeout applied to every agent API call.
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
// ============================================================
|
||||
// AgentClient
|
||||
// ============================================================
|
||||
|
||||
/// Async HTTP client that speaks mTLS to a single Linux Patch API agent.
|
||||
///
|
||||
/// Construct once via [`AgentClient::new`] and reuse across calls;
|
||||
/// the underlying [`reqwest::Client`] maintains a connection pool.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentClient {
|
||||
/// Underlying HTTP client (configured for mTLS + TLS 1.3).
|
||||
inner: reqwest::Client,
|
||||
/// Base URL of the agent, e.g. `https://10.0.0.5:12443/api/v1`.
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl AgentClient {
|
||||
/// Create a new [`AgentClient`] configured for mTLS.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `host_ip` – IP address (or hostname) of the agent.
|
||||
/// * `port` – TCP port the agent listens on (default [`DEFAULT_AGENT_PORT`]).
|
||||
/// * `client_cert_pem` – PEM-encoded client certificate presented during the TLS handshake.
|
||||
/// * `client_key_pem` – PEM-encoded private key matching `client_cert_pem`.
|
||||
/// * `ca_cert_pem` – PEM-encoded CA certificate used to verify the agent's server cert.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AgentClientError::Tls`] when certificate parsing fails, or
|
||||
/// [`AgentClientError::Request`] when `reqwest` client construction fails.
|
||||
pub fn new(
|
||||
host_ip: &str,
|
||||
port: u16,
|
||||
client_cert_pem: &[u8],
|
||||
client_key_pem: &[u8],
|
||||
ca_cert_pem: &[u8],
|
||||
) -> Result<Self, AgentClientError> {
|
||||
// Build client identity: reqwest expects cert + key concatenated as PEM.
|
||||
let mut identity_pem = Vec::with_capacity(client_cert_pem.len() + client_key_pem.len());
|
||||
identity_pem.extend_from_slice(client_cert_pem);
|
||||
identity_pem.extend_from_slice(client_key_pem);
|
||||
|
||||
let identity = Identity::from_pem(&identity_pem)
|
||||
.map_err(|e| AgentClientError::Tls(format!("invalid client identity PEM: {e}")))?;
|
||||
|
||||
// Parse the CA certificate used to verify the agent's server certificate.
|
||||
let ca_cert = Certificate::from_pem(ca_cert_pem)
|
||||
.map_err(|e| AgentClientError::Tls(format!("invalid CA certificate PEM: {e}")))?;
|
||||
|
||||
// Build the reqwest client:
|
||||
// - force rustls TLS backend
|
||||
// - disable built-in OS/system trust roots (only trust our internal CA)
|
||||
// - enforce TLS 1.3 minimum
|
||||
// - attach client identity (mTLS)
|
||||
// - add our CA as a trusted root
|
||||
// - apply a global request timeout
|
||||
let inner = ClientBuilder::new()
|
||||
.use_rustls_tls()
|
||||
.tls_built_in_root_certs(false)
|
||||
.min_tls_version(Version::TLS_1_3)
|
||||
.identity(identity)
|
||||
.add_root_certificate(ca_cert)
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()
|
||||
.map_err(AgentClientError::Request)?;
|
||||
|
||||
let clean_ip = host_ip.split('/').next().unwrap_or(host_ip);
|
||||
let base_url = format!("https://{}:{}/api/v1", clean_ip, port);
|
||||
|
||||
Ok(Self { inner, base_url })
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Public API methods
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// `GET /api/v1/health` — check agent liveness and retrieve uptime.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn health(&self) -> Result<HealthData, AgentClientError> {
|
||||
self.get("health", &[]).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/system/info` — retrieve host system information.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn system_info(&self) -> Result<SystemInfoData, AgentClientError> {
|
||||
self.get("system/info", &[]).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/packages?status=upgradable` — list packages with available upgrades.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn packages_upgradable(&self) -> Result<PackagesData, AgentClientError> {
|
||||
self.get("packages", &[("status", "upgradable")]).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/patches` — list available patches with severity and CVE data.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn patches(&self) -> Result<PatchesData, AgentClientError> {
|
||||
self.get("patches", &[]).await
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Private helpers
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// Execute a GET request against `{base_url}/{path}` with optional query
|
||||
/// parameters, deserialize the [`AgentEnvelope`], and extract the `data`
|
||||
/// field — or propagate an [`AgentClientError::ApiError`].
|
||||
async fn get<T>(&self, path: &str, query: &[(&str, &str)]) -> Result<T, AgentClientError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let url = format!("{}/{}", self.base_url, path);
|
||||
debug!(url = %url, ?query, "Sending GET request to agent");
|
||||
|
||||
let mut request = self.inner.get(&url);
|
||||
if !query.is_empty() {
|
||||
request = request.query(query);
|
||||
}
|
||||
|
||||
let response = request.send().await?;
|
||||
let status = response.status();
|
||||
debug!(url = %url, status = %status, "Received response from agent");
|
||||
|
||||
// Capture body text so we can attempt to deserialise the error envelope
|
||||
// even for non-2xx responses.
|
||||
let body = response.text().await?;
|
||||
|
||||
// Attempt to parse the standard agent envelope regardless of HTTP status.
|
||||
// The agent may embed a structured error body on 4xx/5xx responses.
|
||||
let envelope: AgentEnvelope<T> = serde_json::from_str(&body)?;
|
||||
|
||||
if !status.is_success() || !envelope.success {
|
||||
// Prefer the structured error from the envelope when present.
|
||||
if let Some(err) = envelope.error {
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: err.code,
|
||||
message: err.message,
|
||||
});
|
||||
}
|
||||
// Fallback: use the HTTP status as the error indicator.
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: status.as_str().to_string(),
|
||||
message: format!("Agent returned HTTP {} for {}", status.as_u16(), url),
|
||||
});
|
||||
}
|
||||
|
||||
// On success the `data` field must be present.
|
||||
envelope.data.ok_or_else(|| AgentClientError::ApiError {
|
||||
code: "MISSING_DATA".to_string(),
|
||||
message: "Agent response success=true but data field is absent".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Patch apply / job management methods
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// `POST /api/v1/patches/apply` — trigger patch application on the agent.
|
||||
#[instrument(skip(self, req), fields(base_url = %self.base_url))]
|
||||
pub async fn apply_patches(
|
||||
&self,
|
||||
req: &ApplyPatchesRequest,
|
||||
) -> Result<ApplyPatchesResponse, AgentClientError> {
|
||||
self.post("patches/apply", req).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/jobs/{id}` — poll an async agent job for status.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))]
|
||||
pub async fn job_status(&self, job_id: &str) -> Result<AgentJobStatus, AgentClientError> {
|
||||
self.get(&format!("jobs/{}", job_id), &[]).await
|
||||
}
|
||||
|
||||
/// `POST /api/v1/jobs/{id}/rollback` — trigger rollback on the agent.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))]
|
||||
pub async fn rollback_job(&self, job_id: &str) -> Result<RollbackResponse, AgentClientError> {
|
||||
let empty: serde_json::Value = serde_json::json!({});
|
||||
self.post(&format!("jobs/{}/rollback", job_id), &empty)
|
||||
.await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/system/services/{name}` — check status of a specific service on the agent.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url, service_name = %service_name))]
|
||||
pub async fn service_status(
|
||||
&self,
|
||||
service_name: &str,
|
||||
) -> Result<ServiceStatusData, AgentClientError> {
|
||||
self.get(&format!("system/services/{}", service_name), &[])
|
||||
.await
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Private POST helper
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// Execute a POST request against `{base_url}/{path}`, serialize `body` as
|
||||
/// JSON, deserialize the [`AgentEnvelope`], and extract the `data` field —
|
||||
/// or propagate an [`AgentClientError::ApiError`].
|
||||
async fn post<Req, Resp>(&self, path: &str, body: &Req) -> Result<Resp, AgentClientError>
|
||||
where
|
||||
Req: Serialize,
|
||||
Resp: DeserializeOwned,
|
||||
{
|
||||
let url = format!("{}/{}", self.base_url, path);
|
||||
debug!(url = %url, "Sending POST request to agent");
|
||||
|
||||
let response = self.inner.post(&url).json(body).send().await?;
|
||||
let status = response.status();
|
||||
debug!(url = %url, status = %status, "Received POST response from agent");
|
||||
|
||||
let body_text = response.text().await?;
|
||||
let envelope: AgentEnvelope<Resp> = serde_json::from_str(&body_text)?;
|
||||
|
||||
if !status.is_success() || !envelope.success {
|
||||
if let Some(err) = envelope.error {
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: err.code,
|
||||
message: err.message,
|
||||
});
|
||||
}
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: status.as_str().to_string(),
|
||||
message: format!("Agent returned HTTP {} for {}", status.as_u16(), url),
|
||||
});
|
||||
}
|
||||
|
||||
envelope.data.ok_or_else(|| AgentClientError::ApiError {
|
||||
code: "MISSING_DATA".to_string(),
|
||||
message: "Agent response success=true but data field is absent".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
49
crates/pm-agent-client/src/error.rs
Executable file
49
crates/pm-agent-client/src/error.rs
Executable file
@ -0,0 +1,49 @@
|
||||
//! Error types for the pm-agent-client crate.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Top-level error type returned by [`crate::client::AgentClient`] methods.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AgentClientError {
|
||||
/// TLS configuration or handshake failure.
|
||||
#[error("TLS error: {0}")]
|
||||
Tls(String),
|
||||
|
||||
/// Unable to establish a TCP/TLS connection to the agent.
|
||||
#[error("Connection error: {0}")]
|
||||
Connect(#[source] reqwest::Error),
|
||||
|
||||
/// An HTTP request or response transport error (not a timeout).
|
||||
#[error("Request error: {0}")]
|
||||
Request(#[source] reqwest::Error),
|
||||
|
||||
/// The request did not complete within the configured timeout.
|
||||
#[error("Request timed out")]
|
||||
Timeout,
|
||||
|
||||
/// The agent returned a non-2xx HTTP status or `success: false` in the
|
||||
/// response envelope.
|
||||
#[error("Agent API error [{code}]: {message}")]
|
||||
ApiError {
|
||||
/// Machine-readable error code supplied by the agent (e.g. `"NOT_FOUND"`).
|
||||
code: String,
|
||||
/// Human-readable description returned by the agent.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// JSON deserialization of the agent response failed.
|
||||
#[error("Failed to deserialise agent response: {0}")]
|
||||
Deserialize(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for AgentClientError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
if err.is_timeout() {
|
||||
AgentClientError::Timeout
|
||||
} else if err.is_connect() {
|
||||
AgentClientError::Connect(err)
|
||||
} else {
|
||||
AgentClientError::Request(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
43
crates/pm-agent-client/src/lib.rs
Executable file
43
crates/pm-agent-client/src/lib.rs
Executable file
@ -0,0 +1,43 @@
|
||||
//! `pm-agent-client` — mTLS HTTP client for Linux Patch API agent communication.
|
||||
//!
|
||||
//! This crate provides [`client::AgentClient`], an async HTTP client that
|
||||
//! establishes mutual-TLS connections (TLS 1.3) to `linux_patch_api` agents
|
||||
//! running on managed hosts.
|
||||
//!
|
||||
//! # Quick start
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use pm_agent_client::AgentClient;
|
||||
//!
|
||||
//! # async fn run() -> Result<(), pm_agent_client::AgentClientError> {
|
||||
//! let client = AgentClient::new(
|
||||
//! "10.0.1.5",
|
||||
//! 12443,
|
||||
//! include_bytes!("../certs/client.crt"),
|
||||
//! include_bytes!("../certs/client.key"),
|
||||
//! include_bytes!("../certs/ca.crt"),
|
||||
//! )?;
|
||||
//!
|
||||
//! let health = client.health().await?;
|
||||
//! println!("Agent {}: {}", health.status, health.version);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
pub mod client;
|
||||
pub mod error;
|
||||
pub mod types;
|
||||
|
||||
// ── Convenience re-exports ──────────────────────────────────────────────────
|
||||
|
||||
/// Primary client — re-exported from [`client::AgentClient`].
|
||||
pub use client::{AgentClient, DEFAULT_AGENT_PORT};
|
||||
|
||||
/// Error type — re-exported from [`error::AgentClientError`].
|
||||
pub use error::AgentClientError;
|
||||
|
||||
/// Response envelope and all data types.
|
||||
pub use types::{
|
||||
AgentEnvelope, AgentErrorBody, HealthData, Package, PackagesData, Patch, PatchesData,
|
||||
RollbackResponse, ServiceStatusData, SystemInfoData,
|
||||
};
|
||||
230
crates/pm-agent-client/src/types.rs
Executable file
230
crates/pm-agent-client/src/types.rs
Executable file
@ -0,0 +1,230 @@
|
||||
//! Response and request types for the Linux Patch API agent endpoints.
|
||||
//!
|
||||
//! All agent responses are wrapped in [`AgentEnvelope<T>`].
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
// ============================================================
|
||||
// Envelope & error
|
||||
// ============================================================
|
||||
|
||||
/// Generic response wrapper returned by every agent endpoint.
|
||||
///
|
||||
/// ```json
|
||||
/// { "success": true, "request_id": "…", "timestamp": "…", "data": {…}, "error": null }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentEnvelope<T> {
|
||||
/// `true` when the request succeeded; `false` on error.
|
||||
pub success: bool,
|
||||
/// Server-assigned request identifier (UUID v4).
|
||||
pub request_id: Uuid,
|
||||
/// Server timestamp for the response (ISO-8601 / RFC-3339).
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// Response payload — present when `success` is `true`.
|
||||
pub data: Option<T>,
|
||||
/// Error detail — present when `success` is `false`.
|
||||
pub error: Option<AgentErrorBody>,
|
||||
}
|
||||
|
||||
/// Structured error returned inside [`AgentEnvelope::error`].
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentErrorBody {
|
||||
/// Machine-readable error code (e.g. `"INTERNAL_ERROR"`).
|
||||
pub code: String,
|
||||
/// Human-readable description of what went wrong.
|
||||
pub message: String,
|
||||
/// Optional free-form extra detail from the agent.
|
||||
#[serde(default)]
|
||||
pub details: Option<serde_json::Value>,
|
||||
/// Whether the caller may safely retry the request.
|
||||
#[serde(default)]
|
||||
pub retryable: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/health
|
||||
// ============================================================
|
||||
|
||||
/// Payload returned by `GET /api/v1/health`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct HealthData {
|
||||
/// Agent status string, e.g. `"ok"` or `"degraded"`.
|
||||
pub status: String,
|
||||
/// Seconds elapsed since the agent process started.
|
||||
pub uptime_seconds: u64,
|
||||
/// Agent software version string.
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/system/info
|
||||
// ============================================================
|
||||
|
||||
/// Payload returned by `GET /api/v1/system/info`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct SystemInfoData {
|
||||
/// Hostname of the managed system.
|
||||
pub hostname: String,
|
||||
/// OS family / distribution name (e.g. `"Ubuntu"`).
|
||||
pub os: String,
|
||||
/// OS version string.
|
||||
pub os_version: String,
|
||||
/// Kernel version string.
|
||||
pub kernel: String,
|
||||
/// CPU architecture (e.g. `"x86_64"`).
|
||||
pub architecture: String,
|
||||
/// When the agent last checked for updates (`null` if never).
|
||||
pub last_update_check: Option<DateTime<Utc>>,
|
||||
/// When updates were last applied (`null` if never).
|
||||
pub last_update_apply: Option<DateTime<Utc>>,
|
||||
/// Whether the system has a pending reboot.
|
||||
pub pending_reboot: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/packages?status=upgradable
|
||||
// ============================================================
|
||||
|
||||
/// A single package entry.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Package {
|
||||
/// Package name.
|
||||
pub name: String,
|
||||
/// Installed version.
|
||||
pub version: String,
|
||||
/// Package status string (e.g. `"installed"`, `"upgradable"`).
|
||||
pub status: String,
|
||||
/// Whether a newer version is available.
|
||||
pub upgradable: bool,
|
||||
/// Latest available version (`null` if not upgradable).
|
||||
pub latest_version: Option<String>,
|
||||
/// Short package description.
|
||||
pub description: String,
|
||||
/// CVE identifiers associated with this package.
|
||||
#[serde(default)]
|
||||
pub cve_ids: Vec<String>,
|
||||
}
|
||||
|
||||
/// Payload returned by `GET /api/v1/packages`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct PackagesData {
|
||||
/// List of packages matching the query filters.
|
||||
pub packages: Vec<Package>,
|
||||
/// Total count of matching packages.
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/patches
|
||||
// ============================================================
|
||||
|
||||
/// A single available patch.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Patch {
|
||||
/// Package / patch name.
|
||||
pub name: String,
|
||||
/// Currently installed version.
|
||||
pub current_version: String,
|
||||
/// Version available after applying this patch.
|
||||
pub available_version: String,
|
||||
/// Severity level (e.g. `"critical"`, `"high"`, `"medium"`, `"low"`).
|
||||
pub severity: String,
|
||||
/// Human-readable description of the patch.
|
||||
pub description: String,
|
||||
/// CVE identifiers addressed by this patch.
|
||||
#[serde(default)]
|
||||
pub cve_ids: Vec<String>,
|
||||
/// Whether applying this patch requires a system reboot.
|
||||
pub requires_reboot: bool,
|
||||
}
|
||||
|
||||
/// Payload returned by `GET /api/v1/patches`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct PatchesData {
|
||||
/// List of available patches.
|
||||
pub patches: Vec<Patch>,
|
||||
/// Total patch count.
|
||||
pub total: u64,
|
||||
/// Number of patches classified as security updates.
|
||||
pub security_updates: u64,
|
||||
/// Whether any patch in the list requires a reboot.
|
||||
pub requires_reboot: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/patches/apply
|
||||
// ============================================================
|
||||
|
||||
/// Request body for `POST /api/v1/patches/apply`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApplyPatchesRequest {
|
||||
/// Package names to apply. Empty = apply all available patches.
|
||||
pub packages: Vec<String>,
|
||||
/// If true, allow automatic reboot after patching if required.
|
||||
pub allow_reboot: bool,
|
||||
}
|
||||
|
||||
/// Response from `POST /api/v1/patches/apply`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ApplyPatchesResponse {
|
||||
/// Agent-assigned async job ID for status polling.
|
||||
pub job_id: String,
|
||||
/// Initial status: typically `"running"` or `"queued"`.
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/jobs/{id}
|
||||
// ============================================================
|
||||
|
||||
/// Status of an async agent job returned by `GET /api/v1/jobs/{id}`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentJobStatus {
|
||||
pub job_id: String,
|
||||
/// Current status: `"queued"`, `"running"`, `"succeeded"`, `"completed"`, `"failed"`, or `"cancelled"`.
|
||||
pub status: String,
|
||||
pub progress_percent: Option<u8>,
|
||||
pub output: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/system/services/{name}
|
||||
// ============================================================
|
||||
|
||||
/// Payload returned by `GET /api/v1/system/services/{name}`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ServiceStatusData {
|
||||
/// Service name.
|
||||
pub name: String,
|
||||
/// Human-readable service name.
|
||||
pub display_name: String,
|
||||
/// Active state (e.g. `"active"`, `"inactive"`, `"failed"`).
|
||||
pub active_state: String,
|
||||
/// Sub state (e.g. `"running"`, `"dead"`, `"exited"`).
|
||||
pub sub_state: String,
|
||||
/// Load state (e.g. `"loaded"`, `"not-found"`).
|
||||
pub load_state: String,
|
||||
/// Enabled state (e.g. `"enabled"`, `"disabled"`).
|
||||
pub enabled_state: String,
|
||||
/// Main PID of the service process.
|
||||
pub main_pid: Option<u32>,
|
||||
/// Whether the service is considered healthy.
|
||||
pub healthy: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/jobs/{id}/rollback
|
||||
// ============================================================
|
||||
|
||||
/// Response from `POST /api/v1/jobs/{id}/rollback`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RollbackResponse {
|
||||
pub job_id: String,
|
||||
pub status: String,
|
||||
}
|
||||
29
crates/pm-auth/Cargo.toml
Normal file
29
crates/pm-auth/Cargo.toml
Normal file
@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "pm-auth"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
argon2 = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
totp-rs = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
parking_lot = "0.12"
|
||||
sha2 = { workspace = true }
|
||||
152
crates/pm-auth/src/jwt.rs
Executable file
152
crates/pm-auth/src/jwt.rs
Executable file
@ -0,0 +1,152 @@
|
||||
//! JWT issuance and validation using EdDSA / Ed25519.
|
||||
//!
|
||||
//! - Access tokens: 15-minute TTL, signed with Ed25519 private key
|
||||
//! - Key rotation: 90-day cycle with 24-hour overlap window
|
||||
//! - The web process holds the signing key; worker holds only the public key
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// JWT algorithm — EdDSA with Ed25519 curve.
|
||||
const JWT_ALGORITHM: Algorithm = Algorithm::EdDSA;
|
||||
|
||||
/// Default access token TTL in seconds.
|
||||
pub const DEFAULT_ACCESS_TTL_SECS: i64 = 900; // 15 minutes
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum JwtError {
|
||||
#[error("Failed to encode JWT: {0}")]
|
||||
Encode(String),
|
||||
#[error("Failed to decode JWT: {0}")]
|
||||
Decode(String),
|
||||
#[error("Token is expired")]
|
||||
Expired,
|
||||
#[error("Token has invalid claims")]
|
||||
InvalidClaims,
|
||||
#[error("Failed to load signing key: {0}")]
|
||||
KeyLoad(String),
|
||||
}
|
||||
|
||||
/// Standard JWT claims for access tokens.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccessClaims {
|
||||
/// Subject: user ID (UUID)
|
||||
pub sub: String,
|
||||
/// Issued at (Unix timestamp)
|
||||
pub iat: i64,
|
||||
/// Expiry (Unix timestamp)
|
||||
pub exp: i64,
|
||||
/// JWT ID (unique per token)
|
||||
pub jti: String,
|
||||
/// User role: "admin" or "operator"
|
||||
pub role: String,
|
||||
/// Username (for display / logging)
|
||||
pub username: String,
|
||||
}
|
||||
|
||||
impl AccessClaims {
|
||||
/// Create new claims for the given user.
|
||||
pub fn new(user_id: Uuid, username: &str, role: &str, ttl_secs: i64) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
sub: user_id.to_string(),
|
||||
iat: now.timestamp(),
|
||||
exp: (now + Duration::seconds(ttl_secs)).timestamp(),
|
||||
jti: Uuid::new_v4().to_string(),
|
||||
role: role.to_string(),
|
||||
username: username.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the token is expired (redundant with validation but useful for explicit checks).
|
||||
pub fn is_expired(&self) -> bool {
|
||||
Utc::now().timestamp() > self.exp
|
||||
}
|
||||
|
||||
/// Return the user UUID parsed from the `sub` field.
|
||||
pub fn user_id(&self) -> Result<Uuid, JwtError> {
|
||||
Uuid::parse_str(&self.sub).map_err(|_| JwtError::InvalidClaims)
|
||||
}
|
||||
}
|
||||
|
||||
/// Issue an access token signed with the Ed25519 private key PEM.
|
||||
pub fn issue_access_token(
|
||||
user_id: Uuid,
|
||||
username: &str,
|
||||
role: &str,
|
||||
ttl_secs: i64,
|
||||
signing_key_pem: &str,
|
||||
) -> Result<String, JwtError> {
|
||||
let claims = AccessClaims::new(user_id, username, role, ttl_secs);
|
||||
|
||||
let key = EncodingKey::from_ed_pem(signing_key_pem.as_bytes())
|
||||
.map_err(|e| JwtError::KeyLoad(e.to_string()))?;
|
||||
|
||||
let header = Header::new(JWT_ALGORITHM);
|
||||
|
||||
encode(&header, &claims, &key).map_err(|e| JwtError::Encode(e.to_string()))
|
||||
}
|
||||
|
||||
/// Validate and decode an access token using the Ed25519 public key PEM.
|
||||
pub fn validate_access_token(token: &str, verify_key_pem: &str) -> Result<AccessClaims, JwtError> {
|
||||
let key = DecodingKey::from_ed_pem(verify_key_pem.as_bytes())
|
||||
.map_err(|e| JwtError::KeyLoad(e.to_string()))?;
|
||||
|
||||
let mut validation = Validation::new(JWT_ALGORITHM);
|
||||
validation.validate_exp = true;
|
||||
validation.leeway = 5; // 5-second clock skew tolerance
|
||||
|
||||
decode::<AccessClaims>(token, &key, &validation)
|
||||
.map(|data| data.claims)
|
||||
.map_err(|e| {
|
||||
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
|
||||
JwtError::Expired
|
||||
} else {
|
||||
JwtError::Decode(e.to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Load the Ed25519 signing key from a PEM file path.
|
||||
pub fn load_signing_key(path: &str) -> Result<String, JwtError> {
|
||||
std::fs::read_to_string(path).map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}")))
|
||||
}
|
||||
|
||||
/// Load the Ed25519 verification (public) key from a PEM file path.
|
||||
pub fn load_verify_key(path: &str) -> Result<String, JwtError> {
|
||||
std::fs::read_to_string(path).map_err(|e| JwtError::KeyLoad(format!("Cannot read {path}: {e}")))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Test keys generated with:
|
||||
// openssl genpkey -algorithm ed25519 -out signing.pem
|
||||
// openssl pkey -in signing.pem -pubout -out verify.pem
|
||||
const TEST_SIGNING_KEY: &str = "-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIHNzPc3LkpODUVFr8GjVPm4M2yiKrXsZ/1uJQ/tQMjNb
|
||||
-----END PRIVATE KEY-----
|
||||
";
|
||||
|
||||
const TEST_VERIFY_KEY: &str = "-----BEGIN PUBLIC KEY-----
|
||||
MCowBQYDK2VwAyEA8nRzpCYzZ1xFKNJDGt9wuXdq7kKS/ck9PfLJu/r3VEw=
|
||||
-----END PUBLIC KEY-----
|
||||
";
|
||||
|
||||
// Note: real tests require valid key pairs; these are placeholders.
|
||||
// Integration tests in the test suite use generated keys.
|
||||
#[test]
|
||||
fn claims_construction() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let claims = AccessClaims::new(user_id, "admin", "admin", 900);
|
||||
assert_eq!(claims.sub, user_id.to_string());
|
||||
assert_eq!(claims.role, "admin");
|
||||
assert!(!claims.is_expired());
|
||||
assert_eq!(claims.user_id().unwrap(), user_id);
|
||||
}
|
||||
}
|
||||
25
crates/pm-auth/src/lib.rs
Executable file
25
crates/pm-auth/src/lib.rs
Executable file
@ -0,0 +1,25 @@
|
||||
//! pm-auth — Authentication and authorization.
|
||||
//!
|
||||
//! Modules:
|
||||
//! - `password` — Argon2id password hashing (m=65536, t=3, p=1)
|
||||
//! - `jwt` — EdDSA/Ed25519 JWT issuance and validation (15-min TTL)
|
||||
//! - `refresh` — Opaque 256-bit refresh tokens (1-hour sliding window)
|
||||
//! - `mfa_totp` — TOTP setup and verification (Google Authenticator compatible)
|
||||
//! - `mfa_webauthn` — WebAuthn stub (full implementation pending)
|
||||
//! - `rbac` — Axum middleware for JWT authentication and role enforcement
|
||||
//! - `session` — Login flow orchestration (password → MFA → tokens)
|
||||
|
||||
pub mod jwt;
|
||||
pub mod mfa_totp;
|
||||
pub mod mfa_webauthn;
|
||||
pub mod password;
|
||||
pub mod rbac;
|
||||
pub mod refresh;
|
||||
pub mod session;
|
||||
|
||||
// Commonly re-exported types
|
||||
pub use jwt::{AccessClaims, JwtError};
|
||||
pub use password::validate_password_strength;
|
||||
pub use password::{hash_password, verify_password, PasswordError};
|
||||
pub use rbac::{AuthConfig, AuthUser, UserRole};
|
||||
pub use session::{LoginRequest, LoginResponse, SessionError, SessionUser};
|
||||
103
crates/pm-auth/src/mfa_totp.rs
Executable file
103
crates/pm-auth/src/mfa_totp.rs
Executable file
@ -0,0 +1,103 @@
|
||||
//! TOTP (Time-based One-Time Password) MFA implementation.
|
||||
//!
|
||||
//! Uses TOTP-rs with HMAC-SHA1, 6-digit codes, 30-second window.
|
||||
//! Compatible with Google Authenticator, Authy, and standard TOTP apps.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use totp_rs::{Algorithm, Secret, TOTP};
|
||||
|
||||
/// TOTP issuer label shown in authenticator apps.
|
||||
const ISSUER: &str = "Linux Patch Manager";
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TotpError {
|
||||
#[error("Failed to create TOTP: {0}")]
|
||||
Creation(String),
|
||||
#[error("Invalid TOTP secret")]
|
||||
InvalidSecret,
|
||||
#[error("TOTP code verification failed")]
|
||||
VerificationFailed,
|
||||
}
|
||||
|
||||
/// TOTP setup response returned to the user during MFA enrollment.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TotpSetup {
|
||||
/// Base32-encoded secret for manual entry in authenticator apps.
|
||||
pub secret_base32: String,
|
||||
/// OTP Auth URI for QR code generation (otpauth://totp/...).
|
||||
pub otp_uri: String,
|
||||
}
|
||||
|
||||
/// Generate a new TOTP secret and return setup information.
|
||||
///
|
||||
/// The caller should store `secret_base32` in the database after
|
||||
/// the user verifies the first code.
|
||||
pub fn generate_setup(username: &str) -> Result<TotpSetup, TotpError> {
|
||||
let secret = Secret::generate_secret();
|
||||
let secret_base32 = secret.to_encoded().to_string();
|
||||
|
||||
let totp = build_totp(username, &secret_base32)?;
|
||||
let otp_uri = totp.get_url();
|
||||
|
||||
Ok(TotpSetup {
|
||||
secret_base32,
|
||||
otp_uri,
|
||||
})
|
||||
}
|
||||
|
||||
/// Verify a TOTP code against the stored secret.
|
||||
///
|
||||
/// Accepts codes within a ±1 step window (±30 seconds) to handle clock skew.
|
||||
pub fn verify_code(username: &str, secret_base32: &str, code: &str) -> Result<bool, TotpError> {
|
||||
let totp = build_totp(username, secret_base32)?;
|
||||
let valid = totp
|
||||
.check_current(code)
|
||||
.map_err(|_| TotpError::VerificationFailed)?;
|
||||
Ok(valid)
|
||||
}
|
||||
|
||||
/// Build a TOTP instance from a base32 secret.
|
||||
fn build_totp(username: &str, secret_base32: &str) -> Result<TOTP, TotpError> {
|
||||
let secret = Secret::Encoded(secret_base32.to_string());
|
||||
let secret_bytes = secret.to_bytes().map_err(|_| TotpError::InvalidSecret)?;
|
||||
|
||||
// With the `otpauth` feature, TOTP::new signature is:
|
||||
// new(issuer, account_name, algorithm, digits, skew, step, secret)
|
||||
TOTP::new(
|
||||
Algorithm::SHA1,
|
||||
6, // digits
|
||||
1, // skew
|
||||
30, // step (seconds)
|
||||
secret_bytes,
|
||||
Some(ISSUER.to_string()),
|
||||
username.to_string(),
|
||||
)
|
||||
.map_err(|e| TotpError::Creation(e.to_string()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn generate_setup_produces_valid_uri() {
|
||||
let setup = generate_setup("testuser").unwrap();
|
||||
assert!(!setup.secret_base32.is_empty());
|
||||
assert!(setup.otp_uri.starts_with("otpauth://totp/"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_with_current_code() {
|
||||
let setup = generate_setup("testuser").unwrap();
|
||||
let totp = build_totp("testuser", &setup.secret_base32).unwrap();
|
||||
let code = totp.generate_current().unwrap();
|
||||
assert!(verify_code("testuser", &setup.secret_base32, &code).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_code_fails() {
|
||||
let setup = generate_setup("testuser").unwrap();
|
||||
assert!(!verify_code("testuser", &setup.secret_base32, "000000").unwrap());
|
||||
}
|
||||
}
|
||||
51
crates/pm-auth/src/mfa_webauthn.rs
Executable file
51
crates/pm-auth/src/mfa_webauthn.rs
Executable file
@ -0,0 +1,51 @@
|
||||
//! WebAuthn (FIDO2) MFA stub.
|
||||
//!
|
||||
//! Full implementation planned for M2 extension or M3.
|
||||
//! WebAuthn requires stateful registration/authentication ceremonies
|
||||
//! and a compatible client library (webauthn-rs).
|
||||
//!
|
||||
//! For M2, TOTP is the primary MFA method.
|
||||
//! WebAuthn credentials are stored in the `users.webauthn_credential` JSONB
|
||||
//! column and will be processed here when implemented.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebAuthnError {
|
||||
#[error("WebAuthn not yet implemented")]
|
||||
NotImplemented,
|
||||
}
|
||||
|
||||
/// Placeholder for WebAuthn registration options.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegistrationOptions {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Begin WebAuthn registration ceremony (stub).
|
||||
pub fn begin_registration(_username: &str) -> Result<RegistrationOptions, WebAuthnError> {
|
||||
Err(WebAuthnError::NotImplemented)
|
||||
}
|
||||
|
||||
/// Complete WebAuthn registration ceremony (stub).
|
||||
pub fn complete_registration(
|
||||
_username: &str,
|
||||
_response: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, WebAuthnError> {
|
||||
Err(WebAuthnError::NotImplemented)
|
||||
}
|
||||
|
||||
/// Begin WebAuthn authentication ceremony (stub).
|
||||
pub fn begin_authentication(_username: &str) -> Result<serde_json::Value, WebAuthnError> {
|
||||
Err(WebAuthnError::NotImplemented)
|
||||
}
|
||||
|
||||
/// Verify WebAuthn authentication response (stub).
|
||||
pub fn verify_authentication(
|
||||
_username: &str,
|
||||
_credential: &serde_json::Value,
|
||||
_response: &serde_json::Value,
|
||||
) -> Result<bool, WebAuthnError> {
|
||||
Err(WebAuthnError::NotImplemented)
|
||||
}
|
||||
125
crates/pm-auth/src/password.rs
Executable file
125
crates/pm-auth/src/password.rs
Executable file
@ -0,0 +1,125 @@
|
||||
//! Password hashing and verification using Argon2id.
|
||||
//!
|
||||
//! Parameters (calibrated per OWASP recommendations):
|
||||
//! - Algorithm: Argon2id
|
||||
//! - Memory cost: 65536 KiB (64 MiB)
|
||||
//! - Time cost: 3 iterations
|
||||
//! - Parallelism: 1
|
||||
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2, Params, Version,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Argon2id parameters per spec.
|
||||
const M_COST: u32 = 65536; // 64 MiB
|
||||
const T_COST: u32 = 3; // 3 iterations
|
||||
const P_COST: u32 = 1; // 1 thread
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PasswordError {
|
||||
#[error("Failed to hash password: {0}")]
|
||||
HashError(String),
|
||||
#[error("Failed to verify password: {0}")]
|
||||
VerifyError(String),
|
||||
#[error("Invalid password hash format")]
|
||||
InvalidHash,
|
||||
}
|
||||
|
||||
/// Build an Argon2id instance with calibrated parameters.
|
||||
fn argon2() -> Result<Argon2<'static>, PasswordError> {
|
||||
let params = Params::new(M_COST, T_COST, P_COST, None)
|
||||
.map_err(|e| PasswordError::HashError(e.to_string()))?;
|
||||
Ok(Argon2::new(
|
||||
argon2::Algorithm::Argon2id,
|
||||
Version::V0x13,
|
||||
params,
|
||||
))
|
||||
}
|
||||
|
||||
/// Hash a plaintext password using Argon2id with a random salt.
|
||||
///
|
||||
/// Returns the PHC string format hash suitable for storage.
|
||||
pub fn hash_password(password: &str) -> Result<String, PasswordError> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = argon2()?;
|
||||
|
||||
let hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| PasswordError::HashError(e.to_string()))?;
|
||||
|
||||
Ok(hash.to_string())
|
||||
}
|
||||
|
||||
/// Verify a plaintext password against a stored Argon2id PHC hash.
|
||||
///
|
||||
/// Returns `Ok(true)` if the password matches, `Ok(false)` if not.
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool, PasswordError> {
|
||||
let parsed_hash = PasswordHash::new(hash).map_err(|_| PasswordError::InvalidHash)?;
|
||||
|
||||
let argon2 = argon2()?;
|
||||
|
||||
match argon2.verify_password(password.as_bytes(), &parsed_hash) {
|
||||
Ok(()) => Ok(true),
|
||||
Err(argon2::password_hash::Error::Password) => Ok(false),
|
||||
Err(e) => Err(PasswordError::VerifyError(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate password strength against minimum requirements.
|
||||
///
|
||||
/// Requirements:
|
||||
/// - Minimum 8 characters
|
||||
/// - At least one uppercase letter
|
||||
/// - At least one lowercase letter
|
||||
/// - At least one digit
|
||||
/// - At least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?)
|
||||
pub fn validate_password_strength(password: &str) -> Result<(), String> {
|
||||
if password.len() < 8 {
|
||||
return Err("Password must be at least 8 characters".to_string());
|
||||
}
|
||||
if !password.chars().any(|c| c.is_ascii_uppercase()) {
|
||||
return Err("Password must contain at least one uppercase letter".to_string());
|
||||
}
|
||||
if !password.chars().any(|c| c.is_ascii_lowercase()) {
|
||||
return Err("Password must contain at least one lowercase letter".to_string());
|
||||
}
|
||||
if !password.chars().any(|c| c.is_ascii_digit()) {
|
||||
return Err("Password must contain at least one digit".to_string());
|
||||
}
|
||||
let special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?";
|
||||
if !password.chars().any(|c| special_chars.contains(c)) {
|
||||
return Err(
|
||||
"Password must contain at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?)"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn hash_and_verify_roundtrip() {
|
||||
let password = "super-secret-password-123!";
|
||||
let hash = hash_password(password).unwrap();
|
||||
assert!(hash.starts_with("$argon2id$"));
|
||||
assert!(verify_password(password, &hash).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_password_fails() {
|
||||
let hash = hash_password("correct-horse").unwrap();
|
||||
assert!(!verify_password("wrong-password", &hash).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_salts_produce_different_hashes() {
|
||||
let hash1 = hash_password("same-password").unwrap();
|
||||
let hash2 = hash_password("same-password").unwrap();
|
||||
assert_ne!(hash1, hash2); // different salts
|
||||
}
|
||||
}
|
||||
232
crates/pm-auth/src/rbac.rs
Executable file
232
crates/pm-auth/src/rbac.rs
Executable file
@ -0,0 +1,232 @@
|
||||
//! Role-Based Access Control (RBAC) middleware for Axum.
|
||||
//!
|
||||
//! Provides:
|
||||
//! - JWT extraction and validation from `Authorization: Bearer <token>` header
|
||||
//! - Role enforcement (`admin`, `operator`)
|
||||
//! - Group-scoped access (enforced at the handler level using `AuthUser` extension)
|
||||
//! - IP whitelist enforcement
|
||||
|
||||
use axum::{
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Json, Response},
|
||||
};
|
||||
use ipnet::IpNet;
|
||||
use parking_lot::RwLock;
|
||||
use serde_json::json;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::jwt::{validate_access_token, AccessClaims, JwtError};
|
||||
|
||||
/// User identity extracted from a validated JWT, inserted as a request extension.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthUser {
|
||||
pub user_id: Uuid,
|
||||
pub username: String,
|
||||
pub role: UserRole,
|
||||
pub claims: AccessClaims,
|
||||
}
|
||||
|
||||
/// Application roles.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum UserRole {
|
||||
Admin,
|
||||
Operator,
|
||||
Reporter,
|
||||
}
|
||||
|
||||
impl UserRole {
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"admin" => Some(Self::Admin),
|
||||
"operator" => Some(Self::Operator),
|
||||
"reporter" => Some(Self::Reporter),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Admin => "admin",
|
||||
Self::Operator => "operator",
|
||||
Self::Reporter => "reporter",
|
||||
}
|
||||
}
|
||||
|
||||
/// Admin can do everything; operator has limited scope.
|
||||
pub fn is_admin(&self) -> bool {
|
||||
matches!(self, Self::Admin)
|
||||
}
|
||||
|
||||
/// Admin and Operator can write; Reporter is read-only.
|
||||
pub fn can_write(&self) -> bool {
|
||||
matches!(self, Self::Admin | Self::Operator)
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared auth configuration injected via Axum state.
|
||||
#[derive(Clone)]
|
||||
pub struct AuthConfig {
|
||||
/// Ed25519 public key PEM for JWT verification.
|
||||
pub verify_key_pem: String,
|
||||
/// IP whitelist (empty = allow all). RwLock for runtime updates.
|
||||
pub ip_whitelist: Arc<RwLock<Vec<IpNet>>>,
|
||||
}
|
||||
|
||||
impl AuthConfig {
|
||||
pub fn new(verify_key_pem: String, ip_whitelist_cidrs: &[String]) -> Self {
|
||||
let ip_whitelist = ip_whitelist_cidrs
|
||||
.iter()
|
||||
.filter_map(|cidr| IpNet::from_str(cidr).ok())
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
verify_key_pem,
|
||||
ip_whitelist: Arc::new(RwLock::new(ip_whitelist)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an IP address is allowed by the whitelist.
|
||||
/// If the whitelist is empty, all IPs are allowed.
|
||||
pub fn is_ip_allowed(&self, ip: &IpAddr) -> bool {
|
||||
let whitelist = self.ip_whitelist.read();
|
||||
if whitelist.is_empty() {
|
||||
return true;
|
||||
}
|
||||
whitelist.iter().any(|net| net.contains(ip))
|
||||
}
|
||||
|
||||
/// Update the IP whitelist at runtime without restart.
|
||||
pub fn update_ip_whitelist(&self, entries: Vec<String>) {
|
||||
let nets: Vec<IpNet> = entries
|
||||
.iter()
|
||||
.filter_map(|cidr| IpNet::from_str(cidr).ok())
|
||||
.collect();
|
||||
let count = nets.len();
|
||||
*self.ip_whitelist.write() = nets;
|
||||
tracing::info!(count, "IP whitelist updated at runtime");
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract `Authorization: Bearer <token>` from request headers.
|
||||
fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
|
||||
headers
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.strip_prefix("Bearer "))
|
||||
}
|
||||
|
||||
/// Extract the remote IP from `X-Forwarded-For`.
|
||||
fn extract_remote_ip(headers: &HeaderMap) -> Option<IpAddr> {
|
||||
headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.split(',').next())
|
||||
.and_then(|s| s.trim().parse().ok())
|
||||
}
|
||||
|
||||
/// Unauthorized JSON response helper.
|
||||
fn unauthorized(message: &str) -> Response {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({ "error": { "code": "unauthorized", "message": message } })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Forbidden JSON response helper.
|
||||
fn forbidden(message: &str) -> Response {
|
||||
(
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": message } })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Middleware: authenticate any valid JWT (admin or operator).
|
||||
///
|
||||
/// Inserts `AuthUser` into request extensions on success.
|
||||
/// Rejects with 401 if token is missing/invalid, 403 if IP is blocked.
|
||||
pub async fn require_auth(auth_config: Arc<AuthConfig>, mut req: Request, next: Next) -> Response {
|
||||
// IP whitelist check
|
||||
if let Some(ip) = extract_remote_ip(req.headers()) {
|
||||
if !auth_config.is_ip_allowed(&ip) {
|
||||
tracing::warn!(ip = %ip, "Request blocked by IP whitelist");
|
||||
return forbidden("Access denied");
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and validate JWT
|
||||
let token = match extract_bearer_token(req.headers()) {
|
||||
Some(t) => t,
|
||||
None => return unauthorized("Missing authorization token"),
|
||||
};
|
||||
|
||||
let claims = match validate_access_token(token, &auth_config.verify_key_pem) {
|
||||
Ok(c) => c,
|
||||
Err(JwtError::Expired) => return unauthorized("Token expired"),
|
||||
Err(e) => {
|
||||
tracing::debug!(error = %e, "JWT validation failed");
|
||||
return unauthorized("Invalid token");
|
||||
},
|
||||
};
|
||||
|
||||
let role = match UserRole::from_str(&claims.role) {
|
||||
Some(r) => r,
|
||||
None => return unauthorized("Invalid role in token"),
|
||||
};
|
||||
|
||||
let user_id = match claims.user_id() {
|
||||
Ok(id) => id,
|
||||
Err(_) => return unauthorized("Invalid user ID in token"),
|
||||
};
|
||||
|
||||
let auth_user = AuthUser {
|
||||
user_id,
|
||||
username: claims.username.clone(),
|
||||
role,
|
||||
claims,
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(auth_user);
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// Middleware: require the `admin` role.
|
||||
/// Must be chained AFTER `require_auth` (which inserts `AuthUser`).
|
||||
pub async fn require_admin(req: Request, next: Next) -> Response {
|
||||
let auth_user = match req.extensions().get::<AuthUser>().cloned() {
|
||||
Some(u) => u,
|
||||
None => return unauthorized("Authentication required"),
|
||||
};
|
||||
|
||||
if !auth_user.role.is_admin() {
|
||||
return forbidden("Admin role required");
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// Axum extractor: pulls `AuthUser` from request extensions.
|
||||
impl<S> axum::extract::FromRequestParts<S> for AuthUser
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
_state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
parts
|
||||
.extensions
|
||||
.get::<AuthUser>()
|
||||
.cloned()
|
||||
.ok_or_else(|| unauthorized("Authentication required"))
|
||||
}
|
||||
}
|
||||
163
crates/pm-auth/src/refresh.rs
Executable file
163
crates/pm-auth/src/refresh.rs
Executable file
@ -0,0 +1,163 @@
|
||||
//! Opaque refresh token management.
|
||||
//!
|
||||
//! - 256-bit cryptographically random opaque tokens
|
||||
//! - Stored as SHA-256 hash in the database (never the raw token)
|
||||
//! - 1-hour sliding inactivity timeout, updated on each use
|
||||
//! - Rotated on use (old token revoked, new one issued)
|
||||
//! - Revocable by admin force-logout
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use rand::RngCore;
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Length of the raw refresh token in bytes (256 bits).
|
||||
const TOKEN_BYTES: usize = 32;
|
||||
|
||||
/// Sliding inactivity window: 1 hour.
|
||||
const INACTIVITY_TIMEOUT_HOURS: i64 = 1;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RefreshError {
|
||||
#[error("Refresh token not found or revoked")]
|
||||
Invalid,
|
||||
#[error("Refresh token expired")]
|
||||
Expired,
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
}
|
||||
|
||||
/// Raw (plaintext) refresh token — returned to the client, never stored.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RawRefreshToken(pub String);
|
||||
|
||||
impl RawRefreshToken {
|
||||
/// Hex-encode a raw 256-bit random token.
|
||||
pub fn generate() -> Self {
|
||||
let mut bytes = [0u8; TOKEN_BYTES];
|
||||
rand::thread_rng().fill_bytes(&mut bytes);
|
||||
Self(hex::encode(bytes))
|
||||
}
|
||||
|
||||
/// Return the SHA-256 hash of this token for database storage.
|
||||
pub fn hash(&self) -> String {
|
||||
let digest = Sha256::digest(self.0.as_bytes());
|
||||
hex::encode(digest)
|
||||
}
|
||||
}
|
||||
|
||||
/// Database row representation of a stored refresh token.
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
pub struct StoredRefreshToken {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub expires_at: chrono::DateTime<Utc>,
|
||||
pub revoked: bool,
|
||||
}
|
||||
|
||||
/// Issue a new refresh token for the given user and store it in the database.
|
||||
///
|
||||
/// Returns the raw (plaintext) token to be sent to the client.
|
||||
pub async fn issue(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
user_agent: Option<&str>,
|
||||
ip_address: Option<&str>,
|
||||
) -> Result<RawRefreshToken, RefreshError> {
|
||||
let token = RawRefreshToken::generate();
|
||||
let hash = token.hash();
|
||||
let expires_at = Utc::now() + Duration::hours(INACTIVITY_TIMEOUT_HOURS);
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO refresh_tokens (user_id, token_hash, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, $4, $5::inet)
|
||||
"#,
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&hash)
|
||||
.bind(expires_at)
|
||||
.bind(user_agent)
|
||||
.bind(ip_address)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
tracing::debug!(user_id = %user_id, "Refresh token issued");
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Validate a refresh token, then rotate it (revoke old, issue new).
|
||||
///
|
||||
/// Returns `(new_raw_token, user_id)` if valid.
|
||||
pub async fn rotate(
|
||||
pool: &PgPool,
|
||||
raw_token: &str,
|
||||
user_agent: Option<&str>,
|
||||
ip_address: Option<&str>,
|
||||
) -> Result<(RawRefreshToken, Uuid), RefreshError> {
|
||||
let hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
||||
let now = Utc::now();
|
||||
|
||||
// Look up token
|
||||
let row: Option<StoredRefreshToken> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, user_id, expires_at, revoked
|
||||
FROM refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
"#,
|
||||
)
|
||||
.bind(&hash)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let stored = row.ok_or(RefreshError::Invalid)?;
|
||||
|
||||
if stored.revoked {
|
||||
return Err(RefreshError::Invalid);
|
||||
}
|
||||
|
||||
if stored.expires_at < now {
|
||||
return Err(RefreshError::Expired);
|
||||
}
|
||||
|
||||
// Revoke old token
|
||||
sqlx::query("UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE id = $1")
|
||||
.bind(stored.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Issue new token
|
||||
let new_token = issue(pool, stored.user_id, user_agent, ip_address).await?;
|
||||
|
||||
tracing::debug!(user_id = %stored.user_id, "Refresh token rotated");
|
||||
Ok((new_token, stored.user_id))
|
||||
}
|
||||
|
||||
/// Revoke all refresh tokens for a user (force logout).
|
||||
pub async fn revoke_all_for_user(pool: &PgPool, user_id: Uuid) -> Result<u64, RefreshError> {
|
||||
let result = sqlx::query(
|
||||
"UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE user_id = $1 AND revoked = FALSE",
|
||||
)
|
||||
.bind(user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
tracing::info!(user_id = %user_id, rows = result.rows_affected(), "All refresh tokens revoked");
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Revoke a single refresh token by its raw value.
|
||||
pub async fn revoke(pool: &PgPool, raw_token: &str) -> Result<(), RefreshError> {
|
||||
let hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE refresh_tokens SET revoked = TRUE, revoked_at = NOW() WHERE token_hash = $1",
|
||||
)
|
||||
.bind(&hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
308
crates/pm-auth/src/session.rs
Executable file
308
crates/pm-auth/src/session.rs
Executable file
@ -0,0 +1,308 @@
|
||||
//! Session management: login flow, logout, token issuance.
|
||||
//!
|
||||
//! Login flow: password → MFA → access token + refresh token
|
||||
//! Logout: revoke refresh token
|
||||
//! Force logout: revoke all tokens for a user
|
||||
|
||||
use chrono::Utc;
|
||||
use pm_core::models::{AuthProvider, UserRole};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
jwt::{self, JwtError},
|
||||
mfa_totp,
|
||||
password::{self, PasswordError},
|
||||
refresh::{self, RefreshError},
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SessionError {
|
||||
#[error("Invalid credentials")]
|
||||
InvalidCredentials,
|
||||
#[error("Account is disabled")]
|
||||
AccountDisabled,
|
||||
#[error("Password reset required")]
|
||||
PasswordResetRequired,
|
||||
#[error("MFA required")]
|
||||
MfaRequired,
|
||||
#[error("Invalid MFA code")]
|
||||
InvalidMfaCode,
|
||||
#[error("Account locked due to too many failed attempts")]
|
||||
AccountLocked,
|
||||
#[error("JWT error: {0}")]
|
||||
Jwt(#[from] JwtError),
|
||||
#[error("Refresh token error: {0}")]
|
||||
Refresh(#[from] RefreshError),
|
||||
#[error("Password error: {0}")]
|
||||
Password(#[from] PasswordError),
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
}
|
||||
|
||||
/// Successful login response returned to the client.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginResponse {
|
||||
/// Short-lived JWT access token (15 minutes).
|
||||
pub access_token: String,
|
||||
/// Opaque refresh token (1-hour sliding window).
|
||||
pub refresh_token: String,
|
||||
/// Token type (always "Bearer").
|
||||
pub token_type: String,
|
||||
/// Access token TTL in seconds.
|
||||
pub expires_in: i64,
|
||||
/// User information.
|
||||
pub user: SessionUser,
|
||||
}
|
||||
|
||||
/// User summary embedded in login response.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SessionUser {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub display_name: String,
|
||||
pub role: String,
|
||||
pub mfa_enabled: bool,
|
||||
}
|
||||
|
||||
/// Database user row fetched during login.
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
#[allow(dead_code)]
|
||||
struct DbUser {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
display_name: String,
|
||||
role: UserRole,
|
||||
auth_provider: AuthProvider,
|
||||
password_hash: Option<String>,
|
||||
totp_secret: Option<String>,
|
||||
mfa_enabled: bool,
|
||||
is_active: bool,
|
||||
force_password_reset: bool,
|
||||
failed_login_attempts: i32,
|
||||
locked_until: Option<chrono::DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Login request payload.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
/// TOTP code (required if MFA is enabled).
|
||||
pub totp_code: Option<String>,
|
||||
}
|
||||
|
||||
/// Perform the full login flow for local accounts.
|
||||
///
|
||||
/// Steps:
|
||||
/// 1. Look up user by username
|
||||
/// 2. Verify password (Argon2id)
|
||||
/// 3. Check account active state
|
||||
/// 4. Verify MFA if enabled
|
||||
/// 5. Issue access token + refresh token
|
||||
/// 6. Update last_login_at
|
||||
pub async fn login(
|
||||
pool: &PgPool,
|
||||
req: &LoginRequest,
|
||||
signing_key_pem: &str,
|
||||
access_ttl_secs: i64,
|
||||
user_agent: Option<&str>,
|
||||
ip_address: Option<&str>,
|
||||
) -> Result<LoginResponse, SessionError> {
|
||||
// 1. Fetch user by username
|
||||
let user: Option<DbUser> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, username, display_name, role, auth_provider,
|
||||
password_hash, totp_secret, mfa_enabled, is_active, force_password_reset,
|
||||
failed_login_attempts, locked_until
|
||||
FROM users
|
||||
WHERE username = $1 AND auth_provider = 'local'
|
||||
"#,
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
// Use constant-time comparison approach: always run Argon2 even on miss
|
||||
let user = match user {
|
||||
Some(u) => u,
|
||||
None => {
|
||||
// Prevent timing-based username enumeration
|
||||
let _ = password::hash_password("dummy-timing-fill");
|
||||
return Err(SessionError::InvalidCredentials);
|
||||
},
|
||||
};
|
||||
|
||||
// 2a. Check if account is locked due to too many failed attempts
|
||||
if let Some(locked_until) = user.locked_until {
|
||||
if locked_until > Utc::now() {
|
||||
tracing::warn!(username = %req.username, "Login blocked: account locked until {}", locked_until);
|
||||
return Err(SessionError::AccountLocked);
|
||||
}
|
||||
// Lockout period has expired — reset counters
|
||||
sqlx::query(
|
||||
"UPDATE users SET failed_login_attempts = 0, locked_until = NULL WHERE id = $1",
|
||||
)
|
||||
.bind(user.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// 2. Verify password
|
||||
let hash = user.password_hash.as_deref().unwrap_or("");
|
||||
let valid = password::verify_password(&req.password, hash).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
// Increment failed login attempts
|
||||
let new_attempts = user.failed_login_attempts + 1;
|
||||
if new_attempts >= 5 {
|
||||
let lock_until = Utc::now() + chrono::Duration::minutes(30);
|
||||
sqlx::query(
|
||||
"UPDATE users SET failed_login_attempts = $1, locked_until = $2 WHERE id = $3",
|
||||
)
|
||||
.bind(new_attempts)
|
||||
.bind(lock_until)
|
||||
.bind(user.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
tracing::warn!(username = %req.username, "Account locked after {} failed attempts", new_attempts);
|
||||
} else {
|
||||
sqlx::query("UPDATE users SET failed_login_attempts = $1 WHERE id = $2")
|
||||
.bind(new_attempts)
|
||||
.bind(user.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
tracing::warn!(username = %req.username, "Login failed: invalid password (attempt {})", new_attempts);
|
||||
return Err(SessionError::InvalidCredentials);
|
||||
}
|
||||
|
||||
// 3. Check account state
|
||||
if !user.is_active {
|
||||
tracing::warn!(username = %req.username, "Login failed: account disabled");
|
||||
return Err(SessionError::AccountDisabled);
|
||||
}
|
||||
|
||||
// 3b. Check if password reset is required
|
||||
if user.force_password_reset {
|
||||
tracing::warn!(username = %req.username, "Login blocked: password reset required");
|
||||
return Err(SessionError::PasswordResetRequired);
|
||||
}
|
||||
|
||||
// 4. MFA check
|
||||
if user.mfa_enabled {
|
||||
let code = req.totp_code.as_deref().ok_or(SessionError::MfaRequired)?;
|
||||
let secret = user.totp_secret.as_deref().unwrap_or("");
|
||||
|
||||
let mfa_ok = mfa_totp::verify_code(&user.username, secret, code).unwrap_or(false);
|
||||
|
||||
if !mfa_ok {
|
||||
tracing::warn!(username = %req.username, "Login failed: invalid MFA code");
|
||||
return Err(SessionError::InvalidMfaCode);
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Issue tokens
|
||||
let access_token = jwt::issue_access_token(
|
||||
user.id,
|
||||
&user.username,
|
||||
&user.role.to_string(),
|
||||
access_ttl_secs,
|
||||
signing_key_pem,
|
||||
)?;
|
||||
|
||||
let raw_refresh = refresh::issue(pool, user.id, user_agent, ip_address).await?;
|
||||
|
||||
// 6. Update last_login_at
|
||||
sqlx::query("UPDATE users SET last_login_at = $1, failed_login_attempts = 0, locked_until = NULL WHERE id = $2")
|
||||
.bind(Utc::now())
|
||||
.bind(user.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
tracing::info!(user_id = %user.id, username = %user.username, "Login successful");
|
||||
|
||||
Ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token: raw_refresh.0,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: access_ttl_secs,
|
||||
user: SessionUser {
|
||||
id: user.id.to_string(),
|
||||
username: user.username,
|
||||
display_name: user.display_name,
|
||||
role: user.role.to_string(),
|
||||
mfa_enabled: user.mfa_enabled,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Refresh an access token using a valid refresh token.
|
||||
///
|
||||
/// The old refresh token is revoked and a new one issued (rotation).
|
||||
pub async fn refresh_session(
|
||||
pool: &PgPool,
|
||||
raw_refresh_token: &str,
|
||||
signing_key_pem: &str,
|
||||
access_ttl_secs: i64,
|
||||
user_agent: Option<&str>,
|
||||
ip_address: Option<&str>,
|
||||
) -> Result<LoginResponse, SessionError> {
|
||||
let (new_refresh, user_id) =
|
||||
refresh::rotate(pool, raw_refresh_token, user_agent, ip_address).await?;
|
||||
|
||||
// Fetch user for token claims
|
||||
let user: DbUser = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, username, display_name, role, auth_provider,
|
||||
password_hash, totp_secret, mfa_enabled, is_active, force_password_reset,
|
||||
failed_login_attempts, locked_until
|
||||
FROM users WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if !user.is_active {
|
||||
// Revoke all tokens and deny
|
||||
let _ = refresh::revoke_all_for_user(pool, user_id).await;
|
||||
return Err(SessionError::AccountDisabled);
|
||||
}
|
||||
|
||||
let access_token = jwt::issue_access_token(
|
||||
user.id,
|
||||
&user.username,
|
||||
&user.role.to_string(),
|
||||
access_ttl_secs,
|
||||
signing_key_pem,
|
||||
)?;
|
||||
|
||||
Ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token: new_refresh.0,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: access_ttl_secs,
|
||||
user: SessionUser {
|
||||
id: user.id.to_string(),
|
||||
username: user.username,
|
||||
display_name: user.display_name,
|
||||
role: user.role.to_string(),
|
||||
mfa_enabled: user.mfa_enabled,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Logout: revoke the current refresh token.
|
||||
pub async fn logout(pool: &PgPool, raw_refresh_token: &str) -> Result<(), SessionError> {
|
||||
refresh::revoke(pool, raw_refresh_token).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Force-logout: revoke all refresh tokens for a user.
|
||||
pub async fn force_logout(pool: &PgPool, user_id: Uuid) -> Result<u64, SessionError> {
|
||||
let count = refresh::revoke_all_for_user(pool, user_id).await?;
|
||||
Ok(count)
|
||||
}
|
||||
25
crates/pm-ca/Cargo.toml
Normal file
25
crates/pm-ca/Cargo.toml
Normal file
@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "pm-ca"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
pem = { workspace = true }
|
||||
time = { workspace = true }
|
||||
527
crates/pm-ca/src/ca.rs
Executable file
527
crates/pm-ca/src/ca.rs
Executable file
@ -0,0 +1,527 @@
|
||||
//! Internal Certificate Authority for Linux Patch Manager.
|
||||
//!
|
||||
//! Issues and renews mTLS client certificates and agent server certificates
|
||||
//! for agent communication. Uses rcgen (ECDSA P-256) for all certificate
|
||||
//! generation. CA key and certificate are stored on disk under `base_dir`
|
||||
//! (default: /etc/patch-manager/ca/). Certificate metadata is persisted in
|
||||
//! the `certificates` PostgreSQL table.
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||
use rand::RngCore;
|
||||
use rcgen::{
|
||||
BasicConstraints, Certificate, CertificateParams, DistinguishedName, DnType,
|
||||
ExtendedKeyUsagePurpose, Ia5String, IsCa, KeyPair, KeyUsagePurpose, SanType, SerialNumber,
|
||||
PKCS_ECDSA_P256_SHA256,
|
||||
};
|
||||
use sqlx::{PgPool, Row};
|
||||
use time::{Duration as TimeDuration, OffsetDateTime};
|
||||
use uuid::Uuid;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Returned by [`CertAuthority::issue_client_cert`] and [`CertAuthority::renew_cert`].
|
||||
///
|
||||
/// The private keys are intentionally **not** stored in the database.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IssuedCert {
|
||||
/// PEM-encoded client certificate (mTLS).
|
||||
pub cert_pem: String,
|
||||
/// PEM-encoded client private key (PKCS#8).
|
||||
pub key_pem: String,
|
||||
/// Hex-encoded 16-byte random serial number (client cert).
|
||||
pub serial_number: String,
|
||||
/// Certificate expiry timestamp (UTC).
|
||||
pub expires_at: DateTime<Utc>,
|
||||
/// PEM-encoded agent server certificate (for TLS listener).
|
||||
pub server_cert_pem: String,
|
||||
/// PEM-encoded agent server private key (PKCS#8).
|
||||
pub server_key_pem: String,
|
||||
/// Hex-encoded serial number of the server certificate.
|
||||
pub server_serial_number: String,
|
||||
/// PEM-encoded CA root certificate.
|
||||
pub ca_root_pem: String,
|
||||
}
|
||||
// ---------------------------------------------------------------------------
|
||||
// CertAuthority
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Thread-safe, cloneable handle to the internal certificate authority.
|
||||
///
|
||||
/// CA certificate and key are held in memory as PEM strings; rcgen objects
|
||||
/// are reconstructed on demand so this struct is unconditionally `Send + Sync`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertAuthority {
|
||||
#[allow(dead_code)]
|
||||
base_dir: PathBuf,
|
||||
/// PEM-encoded CA certificate (public cert only).
|
||||
ca_cert_pem: String,
|
||||
/// PEM-encoded CA private key (PKCS#8).
|
||||
ca_key_pem: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Private helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Generate a 16-byte cryptographically-random serial number.
|
||||
/// Returns `(rcgen::SerialNumber, hex_encoded_string)`.
|
||||
fn make_serial() -> (SerialNumber, String) {
|
||||
let mut bytes = [0u8; 16];
|
||||
rand::rngs::OsRng.fill_bytes(&mut bytes);
|
||||
let hex_serial = hex::encode(bytes);
|
||||
let serial = SerialNumber::from_slice(&bytes);
|
||||
(serial, hex_serial)
|
||||
}
|
||||
|
||||
/// `OffsetDateTime::now_utc()` offset forward by `days` (for rcgen params).
|
||||
fn odt_offset_days(days: i64) -> OffsetDateTime {
|
||||
OffsetDateTime::now_utc() + TimeDuration::days(days)
|
||||
}
|
||||
|
||||
/// `chrono::Utc::now()` offset forward by `days` (for DB / return values).
|
||||
fn chrono_offset_days(days: i64) -> DateTime<Utc> {
|
||||
Utc::now() + ChronoDuration::days(days)
|
||||
}
|
||||
|
||||
/// Build a `CertificateParams` with common fields pre-filled.
|
||||
/// Caller still needs to set `is_ca`, `key_usages`, `extended_key_usages`, and `subject_alt_names`.
|
||||
fn base_params(cn: &str, validity_days: i64) -> Result<(CertificateParams, String, DateTime<Utc>)> {
|
||||
let (serial, serial_hex) = make_serial();
|
||||
let expires_at = chrono_offset_days(validity_days);
|
||||
|
||||
let mut params = CertificateParams::default();
|
||||
params.not_before = OffsetDateTime::now_utc();
|
||||
params.not_after = odt_offset_days(validity_days);
|
||||
params.serial_number = Some(serial);
|
||||
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::CommonName, cn);
|
||||
params.distinguished_name = dn;
|
||||
|
||||
Ok((params, serial_hex, expires_at))
|
||||
}
|
||||
|
||||
/// Write `contents` to `path` and set Unix permissions to `0600`.
|
||||
async fn write_protected(path: &Path, contents: &str) -> Result<()> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
tokio::fs::write(path, contents).await?;
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
tokio::fs::set_permissions(path, perms).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// impl CertAuthority
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
impl CertAuthority {
|
||||
// -----------------------------------------------------------------------
|
||||
// Construction
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Load an existing CA from disk, or generate a brand-new one if absent.
|
||||
///
|
||||
/// Files managed:
|
||||
/// * `{base_dir}/ca.key` — PKCS#8 PEM private key (mode `0600`)
|
||||
/// * `{base_dir}/ca.crt` — PEM certificate (mode `0600`)
|
||||
///
|
||||
/// On first generation the CA row is inserted into `certificates`
|
||||
/// with `host_id = NULL` (marks it as the root CA record).
|
||||
pub async fn init(base_dir: &Path, db: &PgPool) -> Result<Self> {
|
||||
let key_path = base_dir.join("ca.key");
|
||||
let crt_path = base_dir.join("ca.crt");
|
||||
|
||||
// ── Load existing CA ──────────────────────────────────────────────
|
||||
if key_path.exists() && crt_path.exists() {
|
||||
tracing::info!(path = %base_dir.display(), "Loading existing root CA from disk");
|
||||
|
||||
let ca_key_pem = tokio::fs::read_to_string(&key_path)
|
||||
.await
|
||||
.context("read ca.key")?;
|
||||
let ca_cert_pem = tokio::fs::read_to_string(&crt_path)
|
||||
.await
|
||||
.context("read ca.crt")?;
|
||||
|
||||
// Validate that both PEMs parse without error.
|
||||
KeyPair::from_pem(&ca_key_pem).context("parse CA private-key PEM")?;
|
||||
CertificateParams::from_ca_cert_pem(&ca_cert_pem)
|
||||
.context("parse CA certificate PEM")?;
|
||||
|
||||
tracing::info!("Root CA loaded successfully");
|
||||
return Ok(Self {
|
||||
base_dir: base_dir.to_owned(),
|
||||
ca_cert_pem,
|
||||
ca_key_pem,
|
||||
});
|
||||
}
|
||||
|
||||
// ── Generate new CA ───────────────────────────────────────────────
|
||||
tracing::info!(
|
||||
path = %base_dir.display(),
|
||||
"Generating new root CA (ECDSA P-256, 10-year validity)"
|
||||
);
|
||||
tokio::fs::create_dir_all(base_dir)
|
||||
.await
|
||||
.context("create CA directory")?;
|
||||
|
||||
let ca_key =
|
||||
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate CA key pair")?;
|
||||
|
||||
let (serial, serial_hex) = make_serial();
|
||||
let expires_at = chrono_offset_days(365 * 10);
|
||||
|
||||
let mut params = CertificateParams::default();
|
||||
params.not_before = OffsetDateTime::now_utc();
|
||||
params.not_after = odt_offset_days(365 * 10);
|
||||
params.serial_number = Some(serial);
|
||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
|
||||
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::CommonName, "Patch Manager Root CA");
|
||||
dn.push(DnType::OrganizationName, "Patch Manager");
|
||||
params.distinguished_name = dn;
|
||||
|
||||
let ca_cert_obj = params
|
||||
.self_signed(&ca_key)
|
||||
.context("self-sign CA certificate")?;
|
||||
let ca_cert_pem = ca_cert_obj.pem();
|
||||
let ca_key_pem = ca_key.serialize_pem();
|
||||
|
||||
write_protected(&key_path, &ca_key_pem)
|
||||
.await
|
||||
.context("write ca.key")?;
|
||||
write_protected(&crt_path, &ca_cert_pem)
|
||||
.await
|
||||
.context("write ca.crt")?;
|
||||
|
||||
tracing::info!(
|
||||
serial = %serial_hex,
|
||||
expires_at = %expires_at,
|
||||
"Root CA generated and written to disk"
|
||||
);
|
||||
|
||||
// Persist CA cert metadata (host_id = NULL marks the root CA row).
|
||||
sqlx::query(
|
||||
"INSERT INTO certificates \
|
||||
(host_id, serial_number, common_name, status, expires_at, cert_pem) \
|
||||
VALUES (NULL, $1, 'Patch Manager Root CA', 'active'::cert_status, $2, $3)",
|
||||
)
|
||||
.bind(&serial_hex)
|
||||
.bind(expires_at)
|
||||
.bind(&ca_cert_pem)
|
||||
.execute(db)
|
||||
.await
|
||||
.context("insert root CA cert into database")?;
|
||||
|
||||
tracing::info!("Root CA certificate recorded in database");
|
||||
|
||||
Ok(Self {
|
||||
base_dir: base_dir.to_owned(),
|
||||
ca_cert_pem,
|
||||
ca_key_pem,
|
||||
})
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Public accessors
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Return the PEM-encoded root CA certificate (public cert only).
|
||||
pub fn root_cert_pem(&self) -> &str {
|
||||
&self.ca_cert_pem
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Certificate issuance
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Issue a one-year mTLS client certificate for a managed host.
|
||||
///
|
||||
/// * Subject: `CN=<hostname>`
|
||||
/// * Key usage: Digital Signature
|
||||
/// * Extended key usage: Client Authentication
|
||||
///
|
||||
/// Also issues a server certificate for the agent's TLS listener
|
||||
/// (see [`issue_server_cert`]).
|
||||
///
|
||||
/// The certificate PEMs are stored in `certificates`.
|
||||
/// The private keys are returned to the caller **only** and never persisted.
|
||||
pub async fn issue_client_cert(
|
||||
&self,
|
||||
host_id: Uuid,
|
||||
hostname: &str,
|
||||
ip_address: &str,
|
||||
db: &PgPool,
|
||||
) -> Result<IssuedCert> {
|
||||
tracing::info!(host_id = %host_id, hostname, ip_address, "Issuing mTLS client certificate");
|
||||
|
||||
let key =
|
||||
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate client key pair")?;
|
||||
|
||||
let (mut params, serial_hex, expires_at) = base_params(hostname, 365)?;
|
||||
params.is_ca = IsCa::ExplicitNoCa;
|
||||
params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth];
|
||||
|
||||
let (ca_key, ca_cert) = self.ca_objects()?;
|
||||
let cert = params
|
||||
.signed_by(&key, &ca_cert, &ca_key)
|
||||
.context("sign client cert with CA")?;
|
||||
|
||||
let cert_pem = cert.pem();
|
||||
let key_pem = key.serialize_pem();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO certificates \
|
||||
(host_id, serial_number, common_name, status, expires_at, cert_pem) \
|
||||
VALUES ($1, $2, $3, 'active'::cert_status, $4, $5)",
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(&serial_hex)
|
||||
.bind(hostname)
|
||||
.bind(expires_at)
|
||||
.bind(&cert_pem)
|
||||
.execute(db)
|
||||
.await
|
||||
.context("insert client cert into database")?;
|
||||
|
||||
tracing::info!(
|
||||
host_id = %host_id,
|
||||
hostname,
|
||||
serial = %serial_hex,
|
||||
expires_at = %expires_at,
|
||||
"Client certificate issued successfully"
|
||||
);
|
||||
|
||||
// Also issue a server certificate for the agent's TLS listener.
|
||||
let (server_cert_pem, server_key_pem, server_serial_number, _server_expires_at) = self
|
||||
.issue_server_cert(host_id, hostname, ip_address, db)
|
||||
.await?;
|
||||
|
||||
Ok(IssuedCert {
|
||||
cert_pem,
|
||||
key_pem,
|
||||
serial_number: serial_hex,
|
||||
expires_at,
|
||||
server_cert_pem,
|
||||
server_key_pem,
|
||||
server_serial_number,
|
||||
ca_root_pem: self.root_cert_pem().to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Issue a one-year server certificate for a managed host's agent TLS listener.
|
||||
///
|
||||
/// * Subject: `CN=<hostname>-server`
|
||||
/// * Key usage: Digital Signature
|
||||
/// * Extended key usage: Server Authentication
|
||||
/// * SANs: DNS `<hostname>` + IP `<ip_address>` (if valid)
|
||||
///
|
||||
/// The certificate PEM is stored in `certificates` with common_name
|
||||
/// `{hostname}-server` to distinguish from client certs.
|
||||
/// The private key is returned to the caller **only** and never persisted.
|
||||
///
|
||||
/// Returns `(cert_pem, key_pem, serial_number, expires_at)`.
|
||||
pub async fn issue_server_cert(
|
||||
&self,
|
||||
host_id: Uuid,
|
||||
hostname: &str,
|
||||
ip_address: &str,
|
||||
db: &PgPool,
|
||||
) -> Result<(String, String, String, DateTime<Utc>)> {
|
||||
tracing::info!(host_id = %host_id, hostname, ip_address, "Issuing agent server certificate");
|
||||
|
||||
let key =
|
||||
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate server key pair")?;
|
||||
|
||||
let server_cn = format!("{hostname}-server");
|
||||
let (mut params, serial_hex, expires_at) = base_params(&server_cn, 365)?;
|
||||
params.is_ca = IsCa::ExplicitNoCa;
|
||||
params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
|
||||
// Build SANs: DNS hostname + optional IP address
|
||||
let mut sans = vec![SanType::DnsName(
|
||||
Ia5String::try_from(hostname.to_owned()).context("hostname is not valid IA5")?,
|
||||
)];
|
||||
// Strip CIDR netmask (e.g. "192.168.3.36/32") before parsing
|
||||
let ip_str = ip_address.split('/').next().unwrap_or(ip_address);
|
||||
if let Ok(ip) = ip_str.parse::<IpAddr>() {
|
||||
sans.push(SanType::IpAddress(ip));
|
||||
} else {
|
||||
tracing::warn!(
|
||||
ip_address,
|
||||
"Could not parse IP address for server cert SANs, skipping IP SAN"
|
||||
);
|
||||
}
|
||||
params.subject_alt_names = sans;
|
||||
|
||||
let (ca_key, ca_cert) = self.ca_objects()?;
|
||||
let cert = params
|
||||
.signed_by(&key, &ca_cert, &ca_key)
|
||||
.context("sign server cert with CA")?;
|
||||
|
||||
let cert_pem = cert.pem();
|
||||
let key_pem = key.serialize_pem();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO certificates \
|
||||
(host_id, serial_number, common_name, status, expires_at, cert_pem) \
|
||||
VALUES ($1, $2, $3, 'active'::cert_status, $4, $5)",
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(&serial_hex)
|
||||
.bind(&server_cn)
|
||||
.bind(expires_at)
|
||||
.bind(&cert_pem)
|
||||
.execute(db)
|
||||
.await
|
||||
.context("insert server cert into database")?;
|
||||
|
||||
tracing::info!(
|
||||
host_id = %host_id,
|
||||
hostname,
|
||||
serial = %serial_hex,
|
||||
expires_at = %expires_at,
|
||||
"Server certificate issued successfully"
|
||||
);
|
||||
|
||||
Ok((cert_pem, key_pem, serial_hex, expires_at))
|
||||
}
|
||||
|
||||
/// Revoke a certificate by database ID.
|
||||
///
|
||||
/// Sets `status = 'revoked'` and `revoked_at = NOW()` in the `certificates` table.
|
||||
/// Does **not** reissue a replacement; use [`renew_cert`] for that.
|
||||
pub async fn revoke_cert(&self, cert_id: Uuid, db: &PgPool) -> Result<()> {
|
||||
tracing::info!(cert_id = %cert_id, "Revoking certificate");
|
||||
|
||||
let rows = sqlx::query(
|
||||
"UPDATE certificates \
|
||||
SET status = 'revoked'::cert_status, revoked_at = NOW() \
|
||||
WHERE id = $1",
|
||||
)
|
||||
.bind(cert_id)
|
||||
.execute(db)
|
||||
.await
|
||||
.context("revoke certificate in database")?;
|
||||
|
||||
if rows.rows_affected() == 0 {
|
||||
anyhow::bail!("certificate not found: {}", cert_id);
|
||||
}
|
||||
|
||||
tracing::info!(cert_id = %cert_id, "Certificate revoked");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Renew a certificate: revoke the existing cert and issue a new one with
|
||||
/// the same `host_id` and `common_name`.
|
||||
///
|
||||
/// Also issues a new server certificate and populates all `IssuedCert` fields.
|
||||
/// The host's IP address is looked up from the database for server cert SANs.
|
||||
pub async fn renew_cert(&self, cert_id: Uuid, db: &PgPool) -> Result<IssuedCert> {
|
||||
tracing::info!(cert_id = %cert_id, "Renewing certificate");
|
||||
|
||||
// Fetch the existing cert's host_id and common_name.
|
||||
let row = sqlx::query("SELECT host_id, common_name FROM certificates WHERE id = $1")
|
||||
.bind(cert_id)
|
||||
.fetch_one(db)
|
||||
.await
|
||||
.context("fetch certificate for renewal")?;
|
||||
|
||||
let host_id: Uuid = row
|
||||
.try_get("host_id")
|
||||
.context("certificate has no host_id (cannot renew root CA)")?;
|
||||
let common_name: String = row.try_get("common_name").context("fetch common_name")?;
|
||||
|
||||
// Look up the host's IP address for the server cert SANs.
|
||||
let ip_address: String =
|
||||
sqlx::query_scalar("SELECT ip_address::text FROM hosts WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.fetch_one(db)
|
||||
.await
|
||||
.context("fetch host IP address for renewal")?;
|
||||
|
||||
// Revoke the old cert first.
|
||||
self.revoke_cert(cert_id, db).await?;
|
||||
|
||||
// Issue a fresh cert with the same CN.
|
||||
let issued = self
|
||||
.issue_client_cert(host_id, &common_name, &ip_address, db)
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
old_cert_id = %cert_id,
|
||||
new_serial = %issued.serial_number,
|
||||
"Certificate renewed"
|
||||
);
|
||||
Ok(issued)
|
||||
}
|
||||
|
||||
/// Generate a self-signed TLS certificate for the web UI using the CA.
|
||||
///
|
||||
/// * Subject: `CN=<hostname>`
|
||||
/// * Key usage: Digital Signature
|
||||
/// * Extended key usage: Server Authentication
|
||||
/// * SAN: DNS `<hostname>`
|
||||
///
|
||||
/// Returns `(cert_pem, key_pem)`. This certificate is **not** stored in the
|
||||
/// database; it is intended for runtime use only.
|
||||
pub async fn issue_web_tls_cert(&self, hostname: &str) -> Result<(String, String)> {
|
||||
tracing::info!(hostname, "Issuing web TLS certificate");
|
||||
|
||||
let key =
|
||||
KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).context("generate web TLS key pair")?;
|
||||
|
||||
let (mut params, serial_hex, expires_at) = base_params(hostname, 365)?;
|
||||
params.is_ca = IsCa::ExplicitNoCa;
|
||||
params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
params.subject_alt_names = vec![SanType::DnsName(
|
||||
Ia5String::try_from(hostname.to_owned()).context("hostname is not valid IA5")?,
|
||||
)];
|
||||
|
||||
let (ca_key, ca_cert) = self.ca_objects()?;
|
||||
let cert = params
|
||||
.signed_by(&key, &ca_cert, &ca_key)
|
||||
.context("sign web TLS cert with CA")?;
|
||||
|
||||
let cert_pem = cert.pem();
|
||||
let key_pem = key.serialize_pem();
|
||||
|
||||
tracing::info!(
|
||||
hostname,
|
||||
serial = %serial_hex,
|
||||
expires_at = %expires_at,
|
||||
"Web TLS certificate issued"
|
||||
);
|
||||
|
||||
Ok((cert_pem, key_pem))
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Private helpers
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Reconstruct rcgen `(KeyPair, Certificate)` from the in-memory PEM strings.
|
||||
///
|
||||
/// The returned `Certificate` is used solely as an issuer reference when
|
||||
/// signing leaf certificates; it is never distributed directly.
|
||||
fn ca_objects(&self) -> Result<(KeyPair, Certificate)> {
|
||||
let key =
|
||||
KeyPair::from_pem(&self.ca_key_pem).context("reconstruct CA key pair from PEM")?;
|
||||
let params = CertificateParams::from_ca_cert_pem(&self.ca_cert_pem)
|
||||
.context("reconstruct CA params from PEM")?;
|
||||
let cert = params
|
||||
.self_signed(&key)
|
||||
.context("reconstruct CA certificate for signing")?;
|
||||
Ok((key, cert))
|
||||
}
|
||||
}
|
||||
7
crates/pm-ca/src/lib.rs
Executable file
7
crates/pm-ca/src/lib.rs
Executable file
@ -0,0 +1,7 @@
|
||||
//! pm-ca — Internal Certificate Authority.
|
||||
//!
|
||||
//! Issues and renews mTLS client certificates for agent communication.
|
||||
//! Uses rcgen + rustls. CA key stored at /etc/patch-manager/ca/.
|
||||
|
||||
pub mod ca;
|
||||
pub use ca::{CertAuthority, IssuedCert};
|
||||
26
crates/pm-core/Cargo.toml
Normal file
26
crates/pm-core/Cargo.toml
Normal file
@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "pm-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
ulid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
config = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
aes-gcm = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
330
crates/pm-core/src/audit.rs
Executable file
330
crates/pm-core/src/audit.rs
Executable file
@ -0,0 +1,330 @@
|
||||
//! Audit log helper functions.
|
||||
//!
|
||||
//! Writes tamper-evident, hash-chained audit events to the `audit_log` table.
|
||||
//! The hash chain: each row's `row_hash` = SHA-256(
|
||||
//! prev_hash || action || actor_user_id || actor_username ||
|
||||
//! target_type || target_id || details_json || ip_address ||
|
||||
//! request_id || created_at
|
||||
//! ).
|
||||
//!
|
||||
//! The `prev_hash` column stores the previous row's `row_hash` for chain
|
||||
//! verification. The first row has `prev_hash = ''`.
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::PgPool;
|
||||
use std::net::IpAddr;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Audit event categories (must match the `audit_action` PostgreSQL ENUM).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AuditAction {
|
||||
UserLogin,
|
||||
UserLogout,
|
||||
UserLoginFailed,
|
||||
UserCreated,
|
||||
UserDeleted,
|
||||
UserUpdated,
|
||||
HostRegistered,
|
||||
HostRemoved,
|
||||
GroupCreated,
|
||||
GroupDeleted,
|
||||
GroupMembershipChanged,
|
||||
PatchJobCreated,
|
||||
PatchJobCancelled,
|
||||
PatchJobRollback,
|
||||
MaintenanceWindowCreated,
|
||||
MaintenanceWindowUpdated,
|
||||
MaintenanceWindowDeleted,
|
||||
CertificateIssued,
|
||||
CertificateRenewed,
|
||||
CertificateRevoked,
|
||||
CertificateDownloaded,
|
||||
ConfigChanged,
|
||||
DiscoveryScanStarted,
|
||||
// M11 additions
|
||||
AuditIntegrityVerified,
|
||||
EmailNotificationSent,
|
||||
PatchJobCompleted,
|
||||
PatchJobFailed,
|
||||
MaintenanceWindowReminder,
|
||||
HealthCheckCreated,
|
||||
HealthCheckUpdated,
|
||||
HealthCheckDeleted,
|
||||
CertificateReissued,
|
||||
}
|
||||
|
||||
impl AuditAction {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::UserLogin => "user_login",
|
||||
Self::UserLogout => "user_logout",
|
||||
Self::UserLoginFailed => "user_login_failed",
|
||||
Self::UserCreated => "user_created",
|
||||
Self::UserDeleted => "user_deleted",
|
||||
Self::UserUpdated => "user_updated",
|
||||
Self::HostRegistered => "host_registered",
|
||||
Self::HostRemoved => "host_removed",
|
||||
Self::GroupCreated => "group_created",
|
||||
Self::GroupDeleted => "group_deleted",
|
||||
Self::GroupMembershipChanged => "group_membership_changed",
|
||||
Self::PatchJobCreated => "patch_job_created",
|
||||
Self::PatchJobCancelled => "patch_job_cancelled",
|
||||
Self::PatchJobRollback => "patch_job_rollback",
|
||||
Self::MaintenanceWindowCreated => "maintenance_window_created",
|
||||
Self::MaintenanceWindowUpdated => "maintenance_window_updated",
|
||||
Self::MaintenanceWindowDeleted => "maintenance_window_deleted",
|
||||
Self::CertificateIssued => "certificate_issued",
|
||||
Self::CertificateRenewed => "certificate_renewed",
|
||||
Self::CertificateRevoked => "certificate_revoked",
|
||||
Self::CertificateDownloaded => "certificate_downloaded",
|
||||
Self::ConfigChanged => "config_changed",
|
||||
Self::DiscoveryScanStarted => "discovery_scan_started",
|
||||
Self::AuditIntegrityVerified => "audit_integrity_verified",
|
||||
Self::EmailNotificationSent => "email_notification_sent",
|
||||
Self::PatchJobCompleted => "patch_job_completed",
|
||||
Self::PatchJobFailed => "patch_job_failed",
|
||||
Self::MaintenanceWindowReminder => "maintenance_window_reminder",
|
||||
Self::HealthCheckCreated => "health_check_created",
|
||||
Self::HealthCheckUpdated => "health_check_updated",
|
||||
Self::HealthCheckDeleted => "health_check_deleted",
|
||||
Self::CertificateReissued => "certificate_reissued",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write an audit event to the database.
|
||||
///
|
||||
/// Computes a hash chain entry using the previous row's hash.
|
||||
/// Non-fatal: logs errors but does not propagate them to avoid
|
||||
/// disrupting the primary operation.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn log_event(
|
||||
pool: &PgPool,
|
||||
action: AuditAction,
|
||||
actor_user_id: Option<Uuid>,
|
||||
actor_username: Option<&str>,
|
||||
target_type: Option<&str>,
|
||||
target_id: Option<&str>,
|
||||
details: serde_json::Value,
|
||||
ip_address: Option<IpAddr>,
|
||||
request_id: Option<&str>,
|
||||
) {
|
||||
let result = write_audit_row(
|
||||
pool,
|
||||
action,
|
||||
actor_user_id,
|
||||
actor_username,
|
||||
target_type,
|
||||
target_id,
|
||||
details,
|
||||
ip_address,
|
||||
request_id,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
tracing::error!(error = %e, action = action.as_str(), "Failed to write audit log");
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn write_audit_row(
|
||||
pool: &PgPool,
|
||||
action: AuditAction,
|
||||
actor_user_id: Option<Uuid>,
|
||||
actor_username: Option<&str>,
|
||||
target_type: Option<&str>,
|
||||
target_id: Option<&str>,
|
||||
details: serde_json::Value,
|
||||
ip_address: Option<IpAddr>,
|
||||
request_id: Option<&str>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
// Fetch previous hash for chain
|
||||
let prev_hash: Option<String> =
|
||||
sqlx::query_scalar("SELECT row_hash FROM audit_log ORDER BY id DESC LIMIT 1")
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let prev = prev_hash.unwrap_or_default();
|
||||
let now = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true);
|
||||
let action_str = action.as_str();
|
||||
let uid_str = actor_user_id.map(|u| u.to_string()).unwrap_or_default();
|
||||
let uname = actor_username.unwrap_or("");
|
||||
let ttype = target_type.unwrap_or("");
|
||||
let tid = target_id.unwrap_or("");
|
||||
let details_str = serde_json::to_string(&details).unwrap_or_default();
|
||||
let ip_str = ip_address.map(|ip| ip.to_string()).unwrap_or_default();
|
||||
let rid = request_id.unwrap_or("");
|
||||
|
||||
// Hash: SHA-256(prev_hash + action + actor_user_id + actor_username +
|
||||
// target_type + target_id + details_json + ip_address +
|
||||
// request_id + created_at)
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(prev.as_bytes());
|
||||
hasher.update(action_str.as_bytes());
|
||||
hasher.update(uid_str.as_bytes());
|
||||
hasher.update(uname.as_bytes());
|
||||
hasher.update(ttype.as_bytes());
|
||||
hasher.update(tid.as_bytes());
|
||||
hasher.update(details_str.as_bytes());
|
||||
hasher.update(ip_str.as_bytes());
|
||||
hasher.update(rid.as_bytes());
|
||||
hasher.update(now.as_bytes());
|
||||
let row_hash = hex::encode(hasher.finalize());
|
||||
|
||||
let ip_for_db = ip_address.map(|ip| ip.to_string());
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO audit_log
|
||||
(action, actor_user_id, actor_username, target_type, target_id,
|
||||
details, ip_address, request_id, created_at, row_hash, prev_hash)
|
||||
VALUES
|
||||
($1::audit_action, $2, $3, $4, $5, $6, $7::inet, $8, $9::timestamptz, $10, $11)
|
||||
"#,
|
||||
)
|
||||
.bind(action_str)
|
||||
.bind(actor_user_id)
|
||||
.bind(actor_username)
|
||||
.bind(target_type)
|
||||
.bind(target_id)
|
||||
.bind(details)
|
||||
.bind(ip_for_db)
|
||||
.bind(request_id)
|
||||
.bind(&now)
|
||||
.bind(&row_hash)
|
||||
.bind(&prev)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Result of an audit integrity verification pass.
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct IntegrityResult {
|
||||
/// Whether the chain is intact (no tampering detected).
|
||||
pub intact: bool,
|
||||
/// Total number of rows checked.
|
||||
pub rows_checked: i64,
|
||||
/// List of errors found (row id, expected hash, actual hash).
|
||||
pub errors: Vec<IntegrityError>,
|
||||
}
|
||||
|
||||
/// A single integrity error detected in the audit chain.
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct IntegrityError {
|
||||
pub row_id: i64,
|
||||
pub expected_hash: String,
|
||||
pub actual_hash: String,
|
||||
}
|
||||
|
||||
/// Row read from audit_log for integrity verification.
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct AuditRow {
|
||||
id: i64,
|
||||
action: String,
|
||||
actor_user_id: Option<uuid::Uuid>,
|
||||
actor_username: Option<String>,
|
||||
target_type: Option<String>,
|
||||
target_id: Option<String>,
|
||||
details: Option<serde_json::Value>,
|
||||
ip_address: Option<String>,
|
||||
request_id: Option<String>,
|
||||
created_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
row_hash: String,
|
||||
prev_hash: String,
|
||||
}
|
||||
|
||||
/// Walk the audit_log rows ordered by id and verify each row_hash matches
|
||||
/// the recomputed hash. Returns an [`IntegrityResult`] describing any
|
||||
/// tampering detected.
|
||||
pub async fn verify_integrity(pool: &PgPool) -> IntegrityResult {
|
||||
let rows: Vec<AuditRow> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, action::text AS action, actor_user_id, actor_username,
|
||||
target_type, target_id, details,
|
||||
host(ip_address) AS ip_address,
|
||||
request_id, created_at, row_hash, prev_hash
|
||||
FROM audit_log
|
||||
ORDER BY id ASC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "verify_integrity: failed to fetch audit rows");
|
||||
return IntegrityResult {
|
||||
intact: false,
|
||||
rows_checked: 0,
|
||||
errors: vec![],
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
let mut errors = Vec::new();
|
||||
let mut expected_prev_hash = String::new();
|
||||
|
||||
for row in &rows {
|
||||
// Verify prev_hash linkage
|
||||
if row.prev_hash != expected_prev_hash {
|
||||
errors.push(IntegrityError {
|
||||
row_id: row.id,
|
||||
expected_hash: expected_prev_hash.clone(),
|
||||
actual_hash: row.prev_hash.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Recompute the row hash from all fields
|
||||
let uid_str = row.actor_user_id.map(|u| u.to_string()).unwrap_or_default();
|
||||
let uname = row.actor_username.as_deref().unwrap_or("");
|
||||
let ttype = row.target_type.as_deref().unwrap_or("");
|
||||
let tid = row.target_id.as_deref().unwrap_or("");
|
||||
let details_str = row
|
||||
.details
|
||||
.as_ref()
|
||||
.and_then(|v| serde_json::to_string(v).ok())
|
||||
.unwrap_or_default();
|
||||
let ip_str = row.ip_address.as_deref().unwrap_or("");
|
||||
let rid = row.request_id.as_deref().unwrap_or("");
|
||||
let created_str = row
|
||||
.created_at
|
||||
.map(|c| c.to_rfc3339_opts(chrono::SecondsFormat::Micros, true))
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(row.prev_hash.as_bytes());
|
||||
hasher.update(row.action.as_bytes());
|
||||
hasher.update(uid_str.as_bytes());
|
||||
hasher.update(uname.as_bytes());
|
||||
hasher.update(ttype.as_bytes());
|
||||
hasher.update(tid.as_bytes());
|
||||
hasher.update(details_str.as_bytes());
|
||||
hasher.update(ip_str.as_bytes());
|
||||
hasher.update(rid.as_bytes());
|
||||
hasher.update(created_str.as_bytes());
|
||||
let computed_hash = hex::encode(hasher.finalize());
|
||||
|
||||
if row.row_hash != computed_hash {
|
||||
errors.push(IntegrityError {
|
||||
row_id: row.id,
|
||||
expected_hash: computed_hash,
|
||||
actual_hash: row.row_hash.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Next row should have this row's hash as prev_hash
|
||||
expected_prev_hash = row.row_hash.clone();
|
||||
}
|
||||
|
||||
let intact = errors.is_empty();
|
||||
let rows_checked = rows.len() as i64;
|
||||
|
||||
IntegrityResult {
|
||||
intact,
|
||||
rows_checked,
|
||||
errors,
|
||||
}
|
||||
}
|
||||
214
crates/pm-core/src/config.rs
Normal file
214
crates/pm-core/src/config.rs
Normal file
@ -0,0 +1,214 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Rate limiting configuration per route group.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RateLimitConfig {
|
||||
/// Enrollment endpoint: requests per minute per IP (default: 5)
|
||||
#[serde(default = "default_enrollment_rpm")]
|
||||
pub enrollment_rpm: u32,
|
||||
/// Enrollment burst allowance (default: 3)
|
||||
#[serde(default = "default_enrollment_burst")]
|
||||
pub enrollment_burst: u32,
|
||||
/// Public auth endpoints: requests per minute per IP (default: 20)
|
||||
#[serde(default = "default_auth_rpm")]
|
||||
pub auth_rpm: u32,
|
||||
/// Auth burst allowance (default: 10)
|
||||
#[serde(default = "default_auth_burst")]
|
||||
pub auth_burst: u32,
|
||||
/// Authenticated API: requests per minute per IP (default: 120)
|
||||
#[serde(default = "default_api_rpm")]
|
||||
pub api_rpm: u32,
|
||||
/// API burst allowance (default: 30)
|
||||
#[serde(default = "default_api_burst")]
|
||||
pub api_burst: u32,
|
||||
}
|
||||
|
||||
fn default_enrollment_rpm() -> u32 {
|
||||
5
|
||||
}
|
||||
fn default_enrollment_burst() -> u32 {
|
||||
3
|
||||
}
|
||||
fn default_auth_rpm() -> u32 {
|
||||
20
|
||||
}
|
||||
fn default_auth_burst() -> u32 {
|
||||
10
|
||||
}
|
||||
fn default_api_rpm() -> u32 {
|
||||
120
|
||||
}
|
||||
fn default_api_burst() -> u32 {
|
||||
30
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enrollment_rpm: default_enrollment_rpm(),
|
||||
enrollment_burst: default_enrollment_burst(),
|
||||
auth_rpm: default_auth_rpm(),
|
||||
auth_burst: default_auth_burst(),
|
||||
api_rpm: default_api_rpm(),
|
||||
api_burst: default_api_burst(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Top-level application configuration.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AppConfig {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub worker: WorkerConfig,
|
||||
pub logging: LoggingConfig,
|
||||
pub security: SecurityConfig,
|
||||
#[serde(default)]
|
||||
pub rate_limit: RateLimitConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ServerConfig {
|
||||
/// Bind address for the web server
|
||||
pub host: String,
|
||||
/// HTTPS port
|
||||
pub port: u16,
|
||||
/// Path to static frontend assets
|
||||
pub static_dir: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct DatabaseConfig {
|
||||
/// Full PostgreSQL connection URL
|
||||
pub url: String,
|
||||
/// Maximum pool connections
|
||||
pub max_connections: u32,
|
||||
/// Minimum pool connections
|
||||
pub min_connections: u32,
|
||||
/// Connection acquire timeout in seconds
|
||||
pub acquire_timeout_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct WorkerConfig {
|
||||
/// Health poll interval in seconds (default: 300 = 5 min)
|
||||
pub health_poll_interval_secs: u64,
|
||||
/// Patch data poll interval in seconds (default: 1800 = 30 min)
|
||||
pub patch_poll_interval_secs: u64,
|
||||
/// Health check poll interval in seconds (default: 300 = 5 min)
|
||||
#[serde(default = "default_health_check_poll_interval")]
|
||||
pub health_check_poll_interval_secs: u64,
|
||||
/// Maximum concurrent agent calls
|
||||
pub max_concurrent_agent_calls: usize,
|
||||
/// Worker heartbeat interval in seconds
|
||||
pub heartbeat_interval_secs: u64,
|
||||
/// WS relay HTTP polling fallback interval in seconds (default: 10)
|
||||
pub ws_relay_poll_interval_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct LoggingConfig {
|
||||
/// Log level filter: trace, debug, info, warn, error
|
||||
pub level: String,
|
||||
/// Output format: json or pretty
|
||||
pub format: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct SecurityConfig {
|
||||
/// IP whitelist (CIDR or individual IPs); empty = allow all (not recommended)
|
||||
pub ip_whitelist: Vec<String>,
|
||||
/// JWT signing key path (Ed25519 PEM)
|
||||
pub jwt_signing_key_path: String,
|
||||
/// JWT verification key path (Ed25519 public PEM)
|
||||
pub jwt_verify_key_path: String,
|
||||
/// JWT access token TTL in seconds (default: 900 = 15 min)
|
||||
pub jwt_access_ttl_secs: u64,
|
||||
/// Agent mTLS client cert path
|
||||
pub agent_client_cert_path: String,
|
||||
/// Agent mTLS client key path
|
||||
pub agent_client_key_path: String,
|
||||
/// Internal CA cert path
|
||||
pub ca_cert_path: String,
|
||||
/// Internal CA key path
|
||||
pub ca_key_path: String,
|
||||
/// Web UI TLS cert path
|
||||
pub web_tls_cert_path: String,
|
||||
/// Web UI TLS key path
|
||||
pub web_tls_key_path: String,
|
||||
/// Frontend URL to redirect to after SSO callback (default: http://localhost:5173/auth/sso/callback)
|
||||
#[serde(default = "default_sso_callback_url")]
|
||||
pub sso_callback_url: String,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// Load configuration from a TOML file and environment variable overrides.
|
||||
///
|
||||
/// Environment variables follow the pattern: `PATCH_MANAGER__SECTION__KEY`
|
||||
/// e.g. `PATCH_MANAGER__DATABASE__URL=postgres://...`
|
||||
pub fn load(config_path: &str) -> Result<Self, ConfigError> {
|
||||
let cfg = Config::builder()
|
||||
.add_source(File::with_name(config_path).required(false))
|
||||
.add_source(
|
||||
Environment::with_prefix("PATCH_MANAGER")
|
||||
.separator("__")
|
||||
.try_parsing(true),
|
||||
)
|
||||
.build()?;
|
||||
|
||||
cfg.try_deserialize()
|
||||
}
|
||||
}
|
||||
|
||||
fn default_health_check_poll_interval() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
fn default_sso_callback_url() -> String {
|
||||
"http://localhost:5173/auth/sso/callback".to_string()
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server: ServerConfig {
|
||||
host: "0.0.0.0".to_string(),
|
||||
port: 443,
|
||||
static_dir: "/usr/share/patch-manager/frontend".to_string(),
|
||||
},
|
||||
database: DatabaseConfig {
|
||||
url: "postgres://patch_manager:changeme@localhost/patch_manager".to_string(),
|
||||
max_connections: 20,
|
||||
min_connections: 2,
|
||||
acquire_timeout_secs: 30,
|
||||
},
|
||||
worker: WorkerConfig {
|
||||
health_poll_interval_secs: 300,
|
||||
patch_poll_interval_secs: 1800,
|
||||
health_check_poll_interval_secs: 300,
|
||||
max_concurrent_agent_calls: 64,
|
||||
heartbeat_interval_secs: 30,
|
||||
ws_relay_poll_interval_secs: 10,
|
||||
},
|
||||
logging: LoggingConfig {
|
||||
level: "info".to_string(),
|
||||
format: "json".to_string(),
|
||||
},
|
||||
security: SecurityConfig {
|
||||
ip_whitelist: vec![],
|
||||
jwt_signing_key_path: "/etc/patch-manager/jwt/signing.pem".to_string(),
|
||||
jwt_verify_key_path: "/etc/patch-manager/jwt/verify.pem".to_string(),
|
||||
jwt_access_ttl_secs: 900,
|
||||
agent_client_cert_path: "/etc/patch-manager/certs/client.crt".to_string(),
|
||||
agent_client_key_path: "/etc/patch-manager/certs/client.key".to_string(),
|
||||
ca_cert_path: "/etc/patch-manager/ca/ca.crt".to_string(),
|
||||
ca_key_path: "/etc/patch-manager/ca/ca.key".to_string(),
|
||||
web_tls_cert_path: "/etc/patch-manager/tls/web.crt".to_string(),
|
||||
web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(),
|
||||
sso_callback_url: default_sso_callback_url(),
|
||||
},
|
||||
rate_limit: RateLimitConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
80
crates/pm-core/src/crypto.rs
Executable file
80
crates/pm-core/src/crypto.rs
Executable file
@ -0,0 +1,80 @@
|
||||
//! AES-256-GCM encryption for sensitive health check credentials.
|
||||
//!
|
||||
//! Uses a per-install key stored at `/etc/patch-manager/keys/health-check.key`.
|
||||
|
||||
use aes_gcm::{
|
||||
aead::{Aead, KeyInit, OsRng},
|
||||
Aes256Gcm, Nonce,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
pub const KEY_PATH: &str = "/etc/patch-manager/keys/health-check.key";
|
||||
|
||||
/// Load or create the per-install encryption key.
|
||||
/// If the key file doesn't exist, generates a new 256-bit key and saves it.
|
||||
pub fn load_or_create_key(path: &Path) -> Result<[u8; 32], CryptoError> {
|
||||
if path.exists() {
|
||||
let key_bytes = fs::read(path).map_err(CryptoError::Io)?;
|
||||
if key_bytes.len() != 32 {
|
||||
return Err(CryptoError::InvalidKeyLength(key_bytes.len()));
|
||||
}
|
||||
let mut key = [0u8; 32];
|
||||
key.copy_from_slice(&key_bytes);
|
||||
Ok(key)
|
||||
} else {
|
||||
let mut key = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut key);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(CryptoError::Io)?;
|
||||
}
|
||||
fs::write(path, key).map_err(CryptoError::Io)?;
|
||||
// Set permissions to 0600 (owner read/write only)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
fs::set_permissions(path, fs::Permissions::from_mode(0o600))
|
||||
.map_err(CryptoError::Io)?;
|
||||
}
|
||||
Ok(key)
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypt plaintext with AES-256-GCM. Returns (ciphertext, nonce).
|
||||
pub fn encrypt(plaintext: &str, key: &[u8; 32]) -> Result<(Vec<u8>, Vec<u8>), CryptoError> {
|
||||
let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::KeyInit(e.to_string()))?;
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
OsRng.fill_bytes(&mut nonce_bytes);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, plaintext.as_bytes())
|
||||
.map_err(|_| CryptoError::EncryptionFailed)?;
|
||||
Ok((ciphertext, nonce_bytes.to_vec()))
|
||||
}
|
||||
|
||||
/// Decrypt AES-256-GCM ciphertext with the given nonce.
|
||||
pub fn decrypt(ciphertext: &[u8], nonce: &[u8], key: &[u8; 32]) -> Result<String, CryptoError> {
|
||||
let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::KeyInit(e.to_string()))?;
|
||||
let nonce = Nonce::from_slice(nonce);
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext)
|
||||
.map_err(|_| CryptoError::DecryptionFailed)?;
|
||||
String::from_utf8(plaintext).map_err(CryptoError::Utf8)
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CryptoError {
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Invalid key length: expected 32 bytes, got {0}")]
|
||||
InvalidKeyLength(usize),
|
||||
#[error("Key init error: {0}")]
|
||||
KeyInit(String),
|
||||
#[error("Encryption failed")]
|
||||
EncryptionFailed,
|
||||
#[error("Decryption failed")]
|
||||
DecryptionFailed,
|
||||
#[error("UTF-8 error: {0}")]
|
||||
Utf8(#[from] std::string::FromUtf8Error),
|
||||
}
|
||||
117
crates/pm-core/src/db.rs
Executable file
117
crates/pm-core/src/db.rs
Executable file
@ -0,0 +1,117 @@
|
||||
use crate::config::DatabaseConfig;
|
||||
use crate::models::{CreateEnrollmentRequest, EnrollmentRequest};
|
||||
use sqlx::postgres::{PgPool, PgPoolOptions};
|
||||
use std::time::Duration;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Initialize and return a PostgreSQL connection pool.
|
||||
pub async fn init_pool(cfg: &DatabaseConfig) -> Result<PgPool, sqlx::Error> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(cfg.max_connections)
|
||||
.min_connections(cfg.min_connections)
|
||||
.acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
|
||||
.connect(&cfg.url)
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
max_connections = cfg.max_connections,
|
||||
"PostgreSQL connection pool initialized"
|
||||
);
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// Run embedded SQLx migrations.
|
||||
/// Uses a PostgreSQL advisory lock to ensure only one writer runs migrations.
|
||||
pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::migrate::MigrateError> {
|
||||
tracing::info!("Acquiring advisory lock for migrations");
|
||||
|
||||
// Advisory lock key — consistent hash of the application name
|
||||
const LOCK_KEY: i64 = 0x7061_7463_686d_6772; // "patchmgr" bytes
|
||||
|
||||
// Acquire advisory lock; blocks until granted
|
||||
sqlx::query("SELECT pg_advisory_lock($1)")
|
||||
.bind(LOCK_KEY)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to acquire advisory lock");
|
||||
e
|
||||
})
|
||||
.expect("Advisory lock must be acquired before running migrations");
|
||||
|
||||
tracing::info!("Running database migrations");
|
||||
let result = sqlx::migrate!("../../migrations").run(pool).await;
|
||||
|
||||
// Always release the lock
|
||||
sqlx::query("SELECT pg_advisory_unlock($1)")
|
||||
.bind(LOCK_KEY)
|
||||
.execute(pool)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
match &result {
|
||||
Ok(_) => tracing::info!("Database migrations completed successfully"),
|
||||
Err(e) => tracing::error!(error = %e, "Database migrations failed"),
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Enrollment Requests
|
||||
// ============================================================
|
||||
|
||||
pub async fn create_enrollment_request(
|
||||
pool: &PgPool,
|
||||
req: CreateEnrollmentRequest,
|
||||
token_hash: String,
|
||||
) -> Result<EnrollmentRequest, sqlx::Error> {
|
||||
sqlx::query_as::<
|
||||
_,
|
||||
EnrollmentRequest,
|
||||
>(
|
||||
r#"
|
||||
INSERT INTO enrollment_requests (machine_id, fqdn, ip_address, os_details, polling_token, hostname)
|
||||
VALUES ($1, $2, $3::inet, $4, $5, $6)
|
||||
RETURNING id, machine_id, fqdn, ip_address::text, os_details, polling_token, hostname, created_at, expires_at
|
||||
"#,
|
||||
)
|
||||
.bind(req.machine_id)
|
||||
.bind(req.fqdn)
|
||||
.bind(req.ip_address)
|
||||
.bind(req.os_details)
|
||||
.bind(token_hash)
|
||||
.bind(&req.hostname)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_enrollment_requests(
|
||||
pool: &PgPool,
|
||||
) -> Result<Vec<EnrollmentRequest>, sqlx::Error> {
|
||||
sqlx::query_as::<_, EnrollmentRequest>(
|
||||
"SELECT id, machine_id, fqdn, ip_address::text, os_details, polling_token, hostname, created_at, expires_at FROM enrollment_requests ORDER BY created_at DESC",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn delete_enrollment_request(pool: &PgPool, id: Uuid) -> Result<u64, sqlx::Error> {
|
||||
let result = sqlx::query("DELETE FROM enrollment_requests WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Check that the database schema is at the expected version.
|
||||
/// Used by the worker to wait until migrations have been applied.
|
||||
pub async fn check_schema_version(pool: &PgPool) -> Result<i64, sqlx::Error> {
|
||||
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_migrations WHERE success = true")
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.0)
|
||||
}
|
||||
126
crates/pm-core/src/error.rs
Executable file
126
crates/pm-core/src/error.rs
Executable file
@ -0,0 +1,126 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Unified application error type.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AppError {
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
|
||||
#[error("Forbidden: {0}")]
|
||||
Forbidden(String),
|
||||
|
||||
#[error("Bad request: {0}")]
|
||||
BadRequest(String),
|
||||
|
||||
#[error("Conflict: {0}")]
|
||||
Conflict(String),
|
||||
|
||||
#[error("Unprocessable entity: {0}")]
|
||||
UnprocessableEntity(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(#[from] anyhow::Error),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
}
|
||||
|
||||
/// JSON error envelope returned to clients.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: ErrorDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorDetail {
|
||||
/// Machine-readable error code (e.g. "not_found", "unauthorized")
|
||||
pub code: String,
|
||||
/// Human-readable message
|
||||
pub message: String,
|
||||
/// Request ID for correlation (set by middleware)
|
||||
pub request_id: Option<String>,
|
||||
/// Optional structured details
|
||||
pub details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error: ErrorDetail {
|
||||
code: code.into(),
|
||||
message: message.into(),
|
||||
request_id: None,
|
||||
details: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
|
||||
self.error.request_id = Some(request_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_details(mut self, details: serde_json::Value) -> Self {
|
||||
self.error.details = Some(details);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, code, message) = match &self {
|
||||
AppError::NotFound(msg) => (StatusCode::NOT_FOUND, "not_found", msg.clone()),
|
||||
AppError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, "unauthorized", msg.clone()),
|
||||
AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, "forbidden", msg.clone()),
|
||||
AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, "bad_request", msg.clone()),
|
||||
AppError::Conflict(msg) => (StatusCode::CONFLICT, "conflict", msg.clone()),
|
||||
AppError::UnprocessableEntity(msg) => (
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"unprocessable_entity",
|
||||
msg.clone(),
|
||||
),
|
||||
AppError::Database(e) => {
|
||||
tracing::error!(error = %e, "Database error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"An internal error occurred".to_string(),
|
||||
)
|
||||
},
|
||||
AppError::Internal(e) => {
|
||||
tracing::error!(error = %e, "Internal error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"An internal error occurred".to_string(),
|
||||
)
|
||||
},
|
||||
AppError::Config(msg) => {
|
||||
tracing::error!(error = %msg, "Configuration error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"config_error",
|
||||
"Server configuration error".to_string(),
|
||||
)
|
||||
},
|
||||
};
|
||||
|
||||
let body = ErrorResponse::new(code, message);
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience alias for handler return types.
|
||||
pub type ApiResult<T> = Result<T, AppError>;
|
||||
23
crates/pm-core/src/lib.rs
Executable file
23
crates/pm-core/src/lib.rs
Executable file
@ -0,0 +1,23 @@
|
||||
pub mod audit;
|
||||
pub mod config;
|
||||
pub mod crypto;
|
||||
pub mod db;
|
||||
pub mod error;
|
||||
pub mod logging;
|
||||
pub mod models;
|
||||
pub mod request_id;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use config::AppConfig;
|
||||
pub use crypto::{decrypt, encrypt, load_or_create_key, CryptoError, KEY_PATH};
|
||||
pub use error::{AppError, ErrorResponse};
|
||||
pub use models::{
|
||||
AdminResetPasswordRequest, AuthProvider, ChangePasswordRequest, CreateGroupRequest,
|
||||
CreateHealthCheckRequest, CreateHostRequest, CreateUserRequest, DiscoveryCidrRequest,
|
||||
DiscoveryResult, Group, HealthCheck, HealthCheckResult, HealthCheckWithResult, Host,
|
||||
HostHealthStatus, HostSummary, RegisterDiscoveredRequest, UpdateGroupRequest,
|
||||
UpdateHealthCheckRequest, UpdateUserRequest, User, UserRole as DbUserRole,
|
||||
};
|
||||
|
||||
// Re-export audit integrity types
|
||||
pub use audit::{verify_integrity, IntegrityError, IntegrityResult};
|
||||
31
crates/pm-core/src/logging.rs
Executable file
31
crates/pm-core/src/logging.rs
Executable file
@ -0,0 +1,31 @@
|
||||
use crate::config::LoggingConfig;
|
||||
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
/// Initialize the global tracing subscriber.
|
||||
///
|
||||
/// Format is controlled by `cfg.format`:
|
||||
/// - `"json"` — machine-readable JSON (production default)
|
||||
/// - anything else — human-readable pretty output (development)
|
||||
///
|
||||
/// Log level is controlled by `cfg.level` (e.g. `"info"`, `"debug"`).
|
||||
/// The `RUST_LOG` environment variable overrides `cfg.level`.
|
||||
pub fn init(cfg: &LoggingConfig) {
|
||||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cfg.level));
|
||||
|
||||
match cfg.format.as_str() {
|
||||
"json" => {
|
||||
tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt::layer().json().with_current_span(true))
|
||||
.init();
|
||||
},
|
||||
_ => {
|
||||
tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt::layer().pretty())
|
||||
.init();
|
||||
},
|
||||
}
|
||||
|
||||
tracing::info!(format = %cfg.format, level = %cfg.level, "Logging initialized");
|
||||
}
|
||||
555
crates/pm-core/src/models.rs
Executable file
555
crates/pm-core/src/models.rs
Executable file
@ -0,0 +1,555 @@
|
||||
//! Shared database model types used across pm-web and pm-worker.
|
||||
//!
|
||||
//! These match the database schema defined in migrations/.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use uuid::Uuid;
|
||||
|
||||
// ============================================================
|
||||
// Enumerations (matching PostgreSQL ENUM types)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[sqlx(type_name = "host_health_status", rename_all = "lowercase")]
|
||||
pub enum HostHealthStatus {
|
||||
Pending,
|
||||
Healthy,
|
||||
Degraded,
|
||||
Unreachable,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for HostHealthStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Pending => write!(f, "pending"),
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded => write!(f, "degraded"),
|
||||
Self::Unreachable => write!(f, "unreachable"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
|
||||
pub enum UserRole {
|
||||
Admin,
|
||||
Operator,
|
||||
Reporter,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UserRole {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Admin => write!(f, "admin"),
|
||||
Self::Operator => write!(f, "operator"),
|
||||
Self::Reporter => write!(f, "reporter"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[sqlx(type_name = "auth_provider", rename_all = "snake_case")]
|
||||
pub enum AuthProvider {
|
||||
Local,
|
||||
#[sqlx(rename = "azure_sso")]
|
||||
AzureSso,
|
||||
Keycloak,
|
||||
Oidc,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AuthProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Local => write!(f, "local"),
|
||||
Self::AzureSso => write!(f, "azure_sso"),
|
||||
Self::Keycloak => write!(f, "keycloak"),
|
||||
Self::Oidc => write!(f, "oidc"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Host
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct Host {
|
||||
pub id: Uuid,
|
||||
pub fqdn: String,
|
||||
pub ip_address: String, // stored as INET, returned as text
|
||||
pub display_name: String,
|
||||
pub os_family: Option<String>,
|
||||
pub os_name: Option<String>,
|
||||
pub arch: Option<String>,
|
||||
pub agent_version: Option<String>,
|
||||
pub health_status: HostHealthStatus,
|
||||
pub last_health_at: Option<DateTime<Utc>>,
|
||||
pub last_patch_at: Option<DateTime<Utc>>,
|
||||
pub agent_port: i32,
|
||||
pub notes: String,
|
||||
pub registered_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Payload for registering a new host.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateHostRequest {
|
||||
/// FQDN or IP address of the managed host
|
||||
pub fqdn: String,
|
||||
pub display_name: Option<String>,
|
||||
pub agent_port: Option<i32>,
|
||||
pub notes: Option<String>,
|
||||
pub group_ids: Option<Vec<Uuid>>,
|
||||
}
|
||||
|
||||
/// Payload for updating an existing host.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateHostRequest {
|
||||
pub fqdn: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Host list item (lighter projection for list views)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct HostSummary {
|
||||
pub id: Uuid,
|
||||
pub fqdn: String,
|
||||
pub ip_address: String,
|
||||
pub display_name: String,
|
||||
pub os_family: Option<String>,
|
||||
pub os_name: Option<String>,
|
||||
pub health_status: HostHealthStatus,
|
||||
pub agent_version: Option<String>,
|
||||
pub patches_missing: i32,
|
||||
pub health_check_status: Option<String>,
|
||||
pub registered_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Host Enrollment
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct EnrollmentRequest {
|
||||
pub id: Uuid,
|
||||
pub machine_id: String,
|
||||
pub fqdn: String,
|
||||
pub ip_address: String,
|
||||
pub os_details: serde_json::Value,
|
||||
pub polling_token: String,
|
||||
/// Short hostname provided during enrollment (optional).
|
||||
pub hostname: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Payload for initial host enrollment request.
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct CreateEnrollmentRequest {
|
||||
pub machine_id: String,
|
||||
pub fqdn: String,
|
||||
pub ip_address: String,
|
||||
pub os_details: serde_json::Value,
|
||||
/// Short hostname (from /etc/hostname, optional).
|
||||
pub hostname: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "status", rename_all = "lowercase")]
|
||||
pub enum EnrollmentStatusResponse {
|
||||
Pending,
|
||||
Approved {
|
||||
ca_crt: String,
|
||||
server_crt: String,
|
||||
server_key: String,
|
||||
},
|
||||
Denied,
|
||||
NotFound,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PkiBundle {
|
||||
pub ca_crt: String,
|
||||
pub server_crt: String,
|
||||
pub server_key: String,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Health Checks
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct HealthCheck {
|
||||
pub id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub name: String,
|
||||
pub check_type: String, // "service" or "http"
|
||||
pub enabled: bool,
|
||||
// Service check fields
|
||||
pub service_name: Option<String>,
|
||||
// HTTP check fields
|
||||
pub url: Option<String>,
|
||||
pub expected_body: Option<String>,
|
||||
pub ignore_cert_errors: bool,
|
||||
pub basic_auth_user: Option<String>,
|
||||
pub target_host_id: Option<Uuid>,
|
||||
// basic_auth_pass_encrypted and nonce NOT exposed in API responses
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HealthCheckWithResult {
|
||||
#[serde(flatten)]
|
||||
pub check: HealthCheck,
|
||||
pub last_result: Option<HealthCheckResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct HealthCheckResult {
|
||||
pub id: Uuid,
|
||||
pub check_id: Uuid,
|
||||
pub healthy: bool,
|
||||
pub detail: Option<String>,
|
||||
pub latency_ms: Option<i32>,
|
||||
pub checked_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateHealthCheckRequest {
|
||||
pub name: String,
|
||||
pub check_type: String, // "service" or "http"
|
||||
pub service_name: Option<String>,
|
||||
pub url: Option<String>,
|
||||
pub expected_body: Option<String>,
|
||||
#[serde(default = "default_true")]
|
||||
pub ignore_cert_errors: bool,
|
||||
pub basic_auth_user: Option<String>,
|
||||
pub basic_auth_pass: Option<String>, // plaintext in request, encrypted before storage
|
||||
pub target_host_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpdateHealthCheckRequest {
|
||||
pub name: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
pub service_name: Option<String>,
|
||||
pub url: Option<String>,
|
||||
pub expected_body: Option<String>,
|
||||
pub ignore_cert_errors: Option<bool>,
|
||||
pub basic_auth_user: Option<String>,
|
||||
pub basic_auth_pass: Option<String>, // if provided, re-encrypt
|
||||
pub target_host_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Group
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct Group {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateGroupRequest {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateGroupRequest {
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// User
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub display_name: String,
|
||||
pub email: String,
|
||||
pub role: UserRole,
|
||||
pub auth_provider: AuthProvider,
|
||||
pub mfa_enabled: bool,
|
||||
pub is_active: bool,
|
||||
pub force_password_reset: bool,
|
||||
pub last_login_at: Option<DateTime<Utc>>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// User create payload (admin-only)
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateUserRequest {
|
||||
pub username: String,
|
||||
pub display_name: Option<String>,
|
||||
pub email: String,
|
||||
pub role: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
/// User update payload
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateUserRequest {
|
||||
pub display_name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub role: Option<String>,
|
||||
pub is_active: Option<bool>,
|
||||
pub force_password_reset: Option<bool>,
|
||||
}
|
||||
|
||||
/// Self-service password change payload
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChangePasswordRequest {
|
||||
pub current_password: String,
|
||||
pub new_password: String,
|
||||
}
|
||||
|
||||
/// Admin password reset payload
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AdminResetPasswordRequest {
|
||||
pub new_password: String,
|
||||
#[serde(default)]
|
||||
pub force_password_reset: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Discovery
|
||||
// ============================================================
|
||||
|
||||
/// Request body for CIDR auto-discovery scan.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DiscoveryCidrRequest {
|
||||
/// CIDR range to scan (e.g. "10.0.0.0/24")
|
||||
pub cidr: String,
|
||||
/// Agent port to probe (default 12443)
|
||||
pub agent_port: Option<i32>,
|
||||
}
|
||||
|
||||
/// A single discovered host result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct DiscoveryResult {
|
||||
pub id: Uuid,
|
||||
pub scan_id: Uuid,
|
||||
pub ip_address: String,
|
||||
pub fqdn: Option<String>,
|
||||
pub agent_version: Option<String>,
|
||||
pub os_name: Option<String>,
|
||||
pub agent_port: i32,
|
||||
pub discovered_at: DateTime<Utc>,
|
||||
pub registered: bool,
|
||||
}
|
||||
|
||||
/// Payload for registering a host from a discovery result.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterDiscoveredRequest {
|
||||
pub discovery_id: Uuid,
|
||||
pub display_name: Option<String>,
|
||||
pub group_ids: Option<Vec<Uuid>>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Patch Jobs
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[sqlx(type_name = "job_status", rename_all = "lowercase")]
|
||||
pub enum JobStatus {
|
||||
Queued,
|
||||
Pending,
|
||||
Running,
|
||||
Succeeded,
|
||||
Failed,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for JobStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Queued => write!(f, "queued"),
|
||||
Self::Pending => write!(f, "pending"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Succeeded => write!(f, "succeeded"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
Self::Cancelled => write!(f, "cancelled"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[sqlx(type_name = "job_kind", rename_all = "snake_case")]
|
||||
pub enum JobKind {
|
||||
#[sqlx(rename = "patch_apply")]
|
||||
PatchApply,
|
||||
#[sqlx(rename = "patch_remove")]
|
||||
PatchRemove,
|
||||
Reboot,
|
||||
Rollback,
|
||||
}
|
||||
|
||||
/// Full `patch_jobs` row.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PatchJob {
|
||||
pub id: Uuid,
|
||||
pub kind: JobKind,
|
||||
pub status: JobStatus,
|
||||
pub created_by_user_id: Option<Uuid>,
|
||||
pub parent_job_id: Option<Uuid>,
|
||||
pub maintenance_window_id: Option<Uuid>,
|
||||
pub immediate: bool,
|
||||
pub patch_selection: serde_json::Value,
|
||||
pub notes: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Full `patch_job_hosts` row (includes columns added in migration 003).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PatchJobHost {
|
||||
pub id: Uuid,
|
||||
pub job_id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub status: JobStatus,
|
||||
pub agent_job_id: Option<String>,
|
||||
pub retry_count: i32,
|
||||
pub output: String,
|
||||
pub error_message: Option<String>,
|
||||
pub retry_next_at: Option<DateTime<Utc>>,
|
||||
pub last_error: Option<String>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Request payload for creating a patch job via `POST /api/v1/jobs`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateJobRequest {
|
||||
/// Host IDs to patch.
|
||||
pub host_ids: Vec<Uuid>,
|
||||
/// Package names to apply (empty = all available patches).
|
||||
pub packages: Vec<String>,
|
||||
/// If true: apply immediately. If false: queue for next maintenance window.
|
||||
pub immediate: bool,
|
||||
/// Optional maintenance window to bind to.
|
||||
pub maintenance_window_id: Option<Uuid>,
|
||||
/// Allow reboot if required by patches.
|
||||
pub allow_reboot: Option<bool>,
|
||||
/// Optional operator notes.
|
||||
pub notes: Option<String>,
|
||||
}
|
||||
|
||||
/// Summary row for job list view (aggregates per-host counts).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PatchJobSummary {
|
||||
pub id: Uuid,
|
||||
pub kind: JobKind,
|
||||
pub status: JobStatus,
|
||||
pub immediate: bool,
|
||||
pub host_count: i64,
|
||||
pub succeeded_count: i64,
|
||||
pub failed_count: i64,
|
||||
pub notes: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Maintenance Windows
|
||||
// ============================================================
|
||||
|
||||
/// Recurrence type for a maintenance window.
|
||||
/// Mirrors the `window_recurrence` PostgreSQL ENUM.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
#[sqlx(type_name = "window_recurrence", rename_all = "lowercase")]
|
||||
pub enum WindowRecurrence {
|
||||
/// Single one-time window (at `start_at` for `duration_minutes` minutes).
|
||||
Once,
|
||||
/// Repeats every day at the time portion of `start_at`.
|
||||
Daily,
|
||||
/// Repeats on the day-of-week in `recurrence_day` (0 = Sunday).
|
||||
Weekly,
|
||||
/// Repeats on the day-of-month in `recurrence_day` (1-31).
|
||||
Monthly,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WindowRecurrence {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Once => write!(f, "once"),
|
||||
Self::Daily => write!(f, "daily"),
|
||||
Self::Weekly => write!(f, "weekly"),
|
||||
Self::Monthly => write!(f, "monthly"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Full row from `maintenance_windows`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct MaintenanceWindow {
|
||||
pub id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub label: String,
|
||||
pub recurrence: WindowRecurrence,
|
||||
/// Absolute start time (one-time) or time-of-day reference (recurring).
|
||||
pub start_at: DateTime<Utc>,
|
||||
/// Duration of the window in minutes.
|
||||
pub duration_minutes: i32,
|
||||
/// Day-of-week (0=Sun, weekly) or day-of-month (1-31, monthly); NULL for once/daily.
|
||||
pub recurrence_day: Option<i32>,
|
||||
pub enabled: bool,
|
||||
pub auto_apply: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Payload for `POST /api/v1/hosts/{id}/maintenance-windows`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateMaintenanceWindowRequest {
|
||||
pub label: String,
|
||||
pub recurrence: WindowRecurrence,
|
||||
/// RFC 3339 / ISO 8601 timestamp (UTC recommended).
|
||||
pub start_at: DateTime<Utc>,
|
||||
/// How many minutes the window is open (default 60).
|
||||
pub duration_minutes: Option<i32>,
|
||||
/// Required for `weekly` (0-6) and `monthly` (1-31).
|
||||
pub recurrence_day: Option<i32>,
|
||||
/// Whether the window is active (default true).
|
||||
pub enabled: Option<bool>,
|
||||
/// Whether to auto-create a patch_apply job when this window opens and patches are pending (default true).
|
||||
pub auto_apply: Option<bool>,
|
||||
}
|
||||
|
||||
/// Payload for `PUT /api/v1/hosts/{id}/maintenance-windows/{window_id}`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateMaintenanceWindowRequest {
|
||||
pub label: Option<String>,
|
||||
pub recurrence: Option<WindowRecurrence>,
|
||||
pub start_at: Option<DateTime<Utc>>,
|
||||
pub duration_minutes: Option<i32>,
|
||||
pub recurrence_day: Option<i32>,
|
||||
pub enabled: Option<bool>,
|
||||
pub auto_apply: Option<bool>,
|
||||
}
|
||||
39
crates/pm-core/src/request_id.rs
Executable file
39
crates/pm-core/src/request_id.rs
Executable file
@ -0,0 +1,39 @@
|
||||
use axum::{extract::Request, http::HeaderValue, middleware::Next, response::Response};
|
||||
use ulid::Ulid;
|
||||
|
||||
/// HTTP header name for request correlation IDs.
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
|
||||
/// Axum middleware that generates a ULID request ID and attaches it
|
||||
/// to both the request extensions and the response header.
|
||||
pub async fn request_id_middleware(mut req: Request, next: Next) -> Response {
|
||||
// Use existing X-Request-Id if provided by upstream proxy, else generate
|
||||
let id = req
|
||||
.headers()
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| Ulid::new().to_string());
|
||||
|
||||
// Insert as extension so handlers can access it
|
||||
req.extensions_mut().insert(RequestId(id.clone()));
|
||||
|
||||
let mut response = next.run(req).await;
|
||||
|
||||
// Echo the ID back in the response
|
||||
if let Ok(value) = HeaderValue::from_str(&id) {
|
||||
response.headers_mut().insert(REQUEST_ID_HEADER, value);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Extractor type for retrieving the request ID inside handlers.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestId(pub String);
|
||||
|
||||
impl std::fmt::Display for RequestId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
24
crates/pm-reports/Cargo.toml
Normal file
24
crates/pm-reports/Cargo.toml
Normal file
@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "pm-reports"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
# Report generation
|
||||
csv = "1"
|
||||
printpdf = { version = "0.7", features = ["embedded_images"] }
|
||||
plotters = { version = "0.3", default-features = false, features = ["bitmap_backend", "bitmap_encoder", "line_series", "area_series", "ttf"] }
|
||||
image = { version = "0.25", default-features = false, features = ["png"] }
|
||||
351
crates/pm-reports/src/csv.rs
Executable file
351
crates/pm-reports/src/csv.rs
Executable file
@ -0,0 +1,351 @@
|
||||
//! CSV report generation for pm-reports.
|
||||
|
||||
use crate::{ReportParams, ReportType};
|
||||
use anyhow::Context;
|
||||
|
||||
/// Generate a CSV report and return the raw bytes.
|
||||
pub async fn generate_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
match params.report_type {
|
||||
ReportType::Compliance => compliance_csv(pool, params).await,
|
||||
ReportType::PatchHistory => patch_history_csv(pool, params).await,
|
||||
ReportType::Vulnerability => vulnerability_csv(pool, params).await,
|
||||
ReportType::Audit => audit_csv(pool, params).await,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compliance
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn compliance_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
let rows = if let Some(gid) = params.group_id {
|
||||
sqlx::query(
|
||||
"
|
||||
SELECT
|
||||
h.id::text AS host_id,
|
||||
h.display_name,
|
||||
h.fqdn,
|
||||
h.health_status::text AS health_status,
|
||||
h.last_patch_at,
|
||||
COALESCE(jsonb_array_length(pd.installed_packages), 0) AS total_packages,
|
||||
COALESCE(pd.patch_count, 0) AS pending_patches,
|
||||
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages), 0) = 0 THEN 100.0
|
||||
ELSE ROUND(CAST((1.0 - pd.patch_count::float / NULLIF(jsonb_array_length(pd.installed_packages), 0)) * 100 AS numeric), 1)
|
||||
END)::float8 AS compliance_pct,
|
||||
COALESCE(string_agg(DISTINCT g.name, ', '), '') AS group_names
|
||||
FROM hosts h
|
||||
LEFT JOIN host_patch_data pd ON pd.host_id = h.id
|
||||
LEFT JOIN host_groups hg ON hg.host_id = h.id
|
||||
LEFT JOIN groups g ON g.id = hg.group_id
|
||||
WHERE h.id IN (
|
||||
SELECT host_id FROM host_groups WHERE group_id = $1
|
||||
)
|
||||
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
|
||||
ORDER BY compliance_pct ASC
|
||||
",
|
||||
)
|
||||
.bind(gid)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("compliance query (group filter) failed")?
|
||||
} else {
|
||||
sqlx::query(
|
||||
"
|
||||
SELECT
|
||||
h.id::text AS host_id,
|
||||
h.display_name,
|
||||
h.fqdn,
|
||||
h.health_status::text AS health_status,
|
||||
h.last_patch_at,
|
||||
COALESCE(jsonb_array_length(pd.installed_packages), 0) AS total_packages,
|
||||
COALESCE(pd.patch_count, 0) AS pending_patches,
|
||||
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages), 0) = 0 THEN 100.0
|
||||
ELSE ROUND(CAST((1.0 - pd.patch_count::float / NULLIF(jsonb_array_length(pd.installed_packages), 0)) * 100 AS numeric), 1)
|
||||
END)::float8 AS compliance_pct,
|
||||
COALESCE(string_agg(DISTINCT g.name, ', '), '') AS group_names
|
||||
FROM hosts h
|
||||
LEFT JOIN host_patch_data pd ON pd.host_id = h.id
|
||||
LEFT JOIN host_groups hg ON hg.host_id = h.id
|
||||
LEFT JOIN groups g ON g.id = hg.group_id
|
||||
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
|
||||
ORDER BY compliance_pct ASC
|
||||
",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("compliance query failed")?
|
||||
};
|
||||
|
||||
let mut wtr = csv::Writer::from_writer(vec![]);
|
||||
wtr.write_record([
|
||||
"host_id",
|
||||
"display_name",
|
||||
"fqdn",
|
||||
"group_names",
|
||||
"total_packages",
|
||||
"pending_patches",
|
||||
"compliance_pct",
|
||||
"last_patch_at",
|
||||
"health_status",
|
||||
])?;
|
||||
|
||||
for row in &rows {
|
||||
use sqlx::Row;
|
||||
let host_id: String = row.try_get("host_id").unwrap_or_default();
|
||||
let display_name: String = row.try_get("display_name").unwrap_or_default();
|
||||
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
|
||||
let group_names: String = row.try_get("group_names").unwrap_or_default();
|
||||
let total_packages: i64 = row.try_get("total_packages").unwrap_or(0);
|
||||
let pending_patches: i64 = row.try_get("pending_patches").unwrap_or(0);
|
||||
let compliance_pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0);
|
||||
let last_patch_at: Option<chrono::DateTime<chrono::Utc>> =
|
||||
row.try_get("last_patch_at").unwrap_or(None);
|
||||
let health_status: String = row.try_get("health_status").unwrap_or_default();
|
||||
|
||||
wtr.write_record(&[
|
||||
host_id,
|
||||
display_name,
|
||||
fqdn,
|
||||
group_names,
|
||||
total_packages.to_string(),
|
||||
pending_patches.to_string(),
|
||||
format!("{:.1}", compliance_pct),
|
||||
last_patch_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
|
||||
health_status,
|
||||
])?;
|
||||
}
|
||||
|
||||
wtr.into_inner().context("csv flush failed")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Patch history
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn patch_history_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
let rows = sqlx::query(
|
||||
"
|
||||
SELECT
|
||||
pj.id::text AS job_id,
|
||||
pj.kind::text AS job_kind,
|
||||
pj.status::text AS job_status,
|
||||
h.display_name,
|
||||
h.fqdn,
|
||||
jsonb_array_length(COALESCE(pj.patch_selection->'packages', '[]'::jsonb)) AS package_count,
|
||||
pjh.started_at,
|
||||
pjh.completed_at,
|
||||
EXTRACT(EPOCH FROM (pjh.completed_at - pjh.started_at))::bigint AS duration_seconds,
|
||||
COALESCE(u.username, 'system') AS operator
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs pj ON pj.id = pjh.job_id
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
LEFT JOIN users u ON u.id = pj.created_by_user_id
|
||||
WHERE ($1::timestamptz IS NULL OR pjh.started_at >= $1)
|
||||
AND ($2::timestamptz IS NULL OR pjh.started_at <= $2)
|
||||
ORDER BY pjh.started_at DESC
|
||||
",
|
||||
)
|
||||
.bind(params.from)
|
||||
.bind(params.to)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("patch history query failed")?;
|
||||
|
||||
let mut wtr = csv::Writer::from_writer(vec![]);
|
||||
wtr.write_record([
|
||||
"job_id",
|
||||
"job_kind",
|
||||
"job_status",
|
||||
"host_display_name",
|
||||
"host_fqdn",
|
||||
"package_count",
|
||||
"started_at",
|
||||
"completed_at",
|
||||
"duration_seconds",
|
||||
"operator",
|
||||
])?;
|
||||
|
||||
for row in &rows {
|
||||
use sqlx::Row;
|
||||
let job_id: String = row.try_get("job_id").unwrap_or_default();
|
||||
let job_kind: String = row.try_get("job_kind").unwrap_or_default();
|
||||
let job_status: String = row.try_get("job_status").unwrap_or_default();
|
||||
let display_name: String = row.try_get("display_name").unwrap_or_default();
|
||||
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
|
||||
let package_count: i64 = row.try_get("package_count").unwrap_or(0);
|
||||
let started_at: Option<chrono::DateTime<chrono::Utc>> =
|
||||
row.try_get("started_at").unwrap_or(None);
|
||||
let completed_at: Option<chrono::DateTime<chrono::Utc>> =
|
||||
row.try_get("completed_at").unwrap_or(None);
|
||||
let duration_seconds: Option<i64> = row.try_get("duration_seconds").unwrap_or(None);
|
||||
let operator: String = row.try_get("operator").unwrap_or_default();
|
||||
|
||||
wtr.write_record(&[
|
||||
job_id,
|
||||
job_kind,
|
||||
job_status,
|
||||
display_name,
|
||||
fqdn,
|
||||
package_count.to_string(),
|
||||
started_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
|
||||
completed_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
|
||||
duration_seconds.unwrap_or(0).to_string(),
|
||||
operator,
|
||||
])?;
|
||||
}
|
||||
|
||||
wtr.into_inner().context("csv flush failed")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Vulnerability
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn vulnerability_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
let mut wtr = csv::Writer::from_writer(vec![]);
|
||||
wtr.write_record([
|
||||
"host_id",
|
||||
"display_name",
|
||||
"fqdn",
|
||||
"cve_id",
|
||||
"package_name",
|
||||
"severity",
|
||||
"available_version",
|
||||
"last_seen_at",
|
||||
])?;
|
||||
|
||||
let result = sqlx::query(
|
||||
"
|
||||
SELECT
|
||||
h.id::text AS host_id,
|
||||
h.display_name,
|
||||
h.fqdn,
|
||||
cve_id,
|
||||
patch->>'name' AS package_name,
|
||||
patch->>'severity' AS severity,
|
||||
patch->>'available_version' AS available_version,
|
||||
pd.polled_at AS last_seen_at
|
||||
FROM hosts h
|
||||
JOIN host_patch_data pd ON pd.host_id = h.id
|
||||
CROSS JOIN LATERAL jsonb_array_elements(COALESCE(pd.available_patches, '[]'::jsonb)) AS patch
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(COALESCE(patch->'cve_ids', '[]'::jsonb)) AS cve_id
|
||||
WHERE ($1::timestamptz IS NULL OR pd.polled_at >= $1)
|
||||
AND ($2::timestamptz IS NULL OR pd.polled_at <= $2)
|
||||
ORDER BY
|
||||
CASE patch->>'severity'
|
||||
WHEN 'critical' THEN 1
|
||||
WHEN 'high' THEN 2
|
||||
WHEN 'medium' THEN 3
|
||||
ELSE 4
|
||||
END,
|
||||
h.display_name
|
||||
",
|
||||
)
|
||||
.bind(params.from)
|
||||
.bind(params.to)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
for row in &rows {
|
||||
use sqlx::Row;
|
||||
let host_id: String = row.try_get("host_id").unwrap_or_default();
|
||||
let display_name: String = row.try_get("display_name").unwrap_or_default();
|
||||
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
|
||||
let cve_id: String = row.try_get("cve_id").unwrap_or_default();
|
||||
let package_name: String = row.try_get("package_name").unwrap_or_default();
|
||||
let severity: String = row.try_get("severity").unwrap_or_default();
|
||||
let available_version: String =
|
||||
row.try_get("available_version").unwrap_or_default();
|
||||
let last_seen_at: Option<chrono::DateTime<chrono::Utc>> =
|
||||
row.try_get("last_seen_at").unwrap_or(None);
|
||||
|
||||
wtr.write_record(&[
|
||||
host_id,
|
||||
display_name,
|
||||
fqdn,
|
||||
cve_id,
|
||||
package_name,
|
||||
severity,
|
||||
available_version,
|
||||
last_seen_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
|
||||
])?;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "vulnerability query failed — returning header-only CSV");
|
||||
// Return header-only CSV (no invalid comment rows)
|
||||
},
|
||||
}
|
||||
|
||||
wtr.into_inner().context("csv flush failed")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Audit
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn audit_csv(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
let rows = sqlx::query(
|
||||
"
|
||||
SELECT
|
||||
id::text AS id,
|
||||
created_at,
|
||||
action::text AS action,
|
||||
actor_username,
|
||||
target_type,
|
||||
target_id,
|
||||
ip_address::text AS ip_address,
|
||||
request_id
|
||||
FROM audit_log
|
||||
WHERE ($1::timestamptz IS NULL OR created_at >= $1)
|
||||
AND ($2::timestamptz IS NULL OR created_at <= $2)
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10000
|
||||
",
|
||||
)
|
||||
.bind(params.from)
|
||||
.bind(params.to)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("audit query failed")?;
|
||||
|
||||
let mut wtr = csv::Writer::from_writer(vec![]);
|
||||
wtr.write_record([
|
||||
"id",
|
||||
"created_at",
|
||||
"action",
|
||||
"actor_username",
|
||||
"target_type",
|
||||
"target_id",
|
||||
"ip_address",
|
||||
"request_id",
|
||||
])?;
|
||||
|
||||
for row in &rows {
|
||||
use sqlx::Row;
|
||||
let id: String = row.try_get("id").unwrap_or_default();
|
||||
let created_at: Option<chrono::DateTime<chrono::Utc>> =
|
||||
row.try_get("created_at").unwrap_or(None);
|
||||
let action: String = row.try_get("action").unwrap_or_default();
|
||||
let actor_username: String = row.try_get("actor_username").unwrap_or_default();
|
||||
let target_type: String = row.try_get("target_type").unwrap_or_default();
|
||||
let target_id: String = row.try_get("target_id").unwrap_or_default();
|
||||
let ip_address: String = row.try_get("ip_address").unwrap_or_default();
|
||||
let request_id: String = row.try_get("request_id").unwrap_or_default();
|
||||
|
||||
wtr.write_record(&[
|
||||
id,
|
||||
created_at.map(|d| d.to_rfc3339()).unwrap_or_default(),
|
||||
action,
|
||||
actor_username,
|
||||
target_type,
|
||||
target_id,
|
||||
ip_address,
|
||||
request_id,
|
||||
])?;
|
||||
}
|
||||
|
||||
wtr.into_inner().context("csv flush failed")
|
||||
}
|
||||
27
crates/pm-reports/src/lib.rs
Executable file
27
crates/pm-reports/src/lib.rs
Executable file
@ -0,0 +1,27 @@
|
||||
//! pm-reports — CSV and PDF report generation.
|
||||
//!
|
||||
//! Uses printpdf + plotters for in-process PDF with charts.
|
||||
pub mod csv;
|
||||
pub mod pdf;
|
||||
|
||||
pub use csv::generate_csv;
|
||||
pub use pdf::generate_pdf;
|
||||
|
||||
/// The type of report to generate.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum ReportType {
|
||||
Compliance,
|
||||
PatchHistory,
|
||||
Vulnerability,
|
||||
Audit,
|
||||
}
|
||||
|
||||
/// Parameters controlling report generation.
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct ReportParams {
|
||||
pub report_type: ReportType,
|
||||
pub from: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub to: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub group_id: Option<uuid::Uuid>,
|
||||
}
|
||||
599
crates/pm-reports/src/pdf.rs
Executable file
599
crates/pm-reports/src/pdf.rs
Executable file
@ -0,0 +1,599 @@
|
||||
//! PDF report generation for pm-reports.
|
||||
//!
|
||||
//! Uses printpdf for document structure and plotters + image for embedded charts.
|
||||
|
||||
use crate::{ReportParams, ReportType};
|
||||
use anyhow::Context;
|
||||
use plotters::prelude::*;
|
||||
use printpdf::{
|
||||
BuiltinFont, ColorBits, ColorSpace, Image, ImageTransform, ImageXObject, IndirectFontRef, Mm,
|
||||
PdfDocument, PdfLayerIndex, PdfLayerReference, PdfPageIndex, Px,
|
||||
};
|
||||
|
||||
const PAGE_W: f32 = 297.0; // A4 landscape width (mm)
|
||||
const PAGE_H: f32 = 210.0; // A4 landscape height (mm)
|
||||
const MARGIN: f32 = 10.0;
|
||||
const ROW_H: f32 = 6.0;
|
||||
const HEADER_Y_START: f32 = 190.0;
|
||||
const NEW_PAGE_THRESHOLD: f32 = 20.0;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Generate a PDF report and return the raw bytes.
|
||||
pub async fn generate_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
match params.report_type {
|
||||
ReportType::Compliance => compliance_pdf(pool, params).await,
|
||||
ReportType::PatchHistory => patch_history_pdf(pool, params).await,
|
||||
ReportType::Vulnerability => vulnerability_pdf(pool, params).await,
|
||||
ReportType::Audit => audit_pdf(pool, params).await,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Chart helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Render a bar chart to an in-memory PNG and return the raw PNG bytes.
|
||||
fn render_bar_chart(
|
||||
labels: &[String],
|
||||
values: &[f64],
|
||||
title: &str,
|
||||
) -> anyhow::Result<(Vec<u8>, u32, u32)> {
|
||||
const W: u32 = 800;
|
||||
const H: u32 = 400;
|
||||
|
||||
let mut pixel_buf = vec![0u8; (W * H * 3) as usize];
|
||||
|
||||
// Guard: skip rendering for empty or mismatched data
|
||||
if labels.is_empty() || values.is_empty() || labels.len() != values.len() {
|
||||
anyhow::bail!("cannot render bar chart with empty or mismatched data");
|
||||
}
|
||||
|
||||
// Guard: reject NaN / infinity which would panic in plotters
|
||||
if values.iter().any(|v| !v.is_finite()) {
|
||||
anyhow::bail!("bar chart values contain non-finite numbers");
|
||||
}
|
||||
|
||||
{
|
||||
let root = BitMapBackend::with_buffer(&mut pixel_buf, (W, H)).into_drawing_area();
|
||||
root.fill(&WHITE)?;
|
||||
|
||||
let max_val = values.iter().cloned().fold(0.0_f64, f64::max).max(1.0);
|
||||
let n = labels.len().max(1);
|
||||
|
||||
let mut chart = ChartBuilder::on(&root)
|
||||
.caption(title, ("sans-serif", 20).into_font())
|
||||
.margin(20u32)
|
||||
.x_label_area_size(60u32)
|
||||
.y_label_area_size(50u32)
|
||||
.build_cartesian_2d(0..n, 0.0..max_val * 1.1)?;
|
||||
|
||||
chart
|
||||
.configure_mesh()
|
||||
.x_labels(n.min(20))
|
||||
.x_label_formatter(&|idx| {
|
||||
labels
|
||||
.get(*idx)
|
||||
.map(|s| {
|
||||
if s.len() > 12 {
|
||||
s[..12].to_string()
|
||||
} else {
|
||||
s.clone()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.y_desc("Value")
|
||||
.draw()?;
|
||||
|
||||
chart.draw_series((0..n).map(|i| {
|
||||
let v = values.get(i).copied().unwrap_or(0.0);
|
||||
let color = if v >= 90.0 {
|
||||
RGBColor(76, 175, 80)
|
||||
} else if v >= 70.0 {
|
||||
RGBColor(255, 193, 7)
|
||||
} else {
|
||||
RGBColor(244, 67, 54)
|
||||
};
|
||||
Rectangle::new([(i, 0.0), (i + 1, v)], color.filled())
|
||||
}))?;
|
||||
|
||||
root.present()?;
|
||||
}
|
||||
|
||||
// Return raw RGB pixels + dimensions for direct PDF embedding
|
||||
Ok((pixel_buf, W, H))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PDF builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct PdfBuilder {
|
||||
doc: printpdf::PdfDocumentReference,
|
||||
font: IndirectFontRef,
|
||||
font_bold: IndirectFontRef,
|
||||
page_idx: PdfPageIndex,
|
||||
layer_idx: PdfLayerIndex,
|
||||
current_y: f32,
|
||||
}
|
||||
|
||||
impl PdfBuilder {
|
||||
fn new(title: &str) -> anyhow::Result<Self> {
|
||||
let doc = PdfDocument::empty(title);
|
||||
let (page_idx, layer_idx) = doc.add_page(Mm(PAGE_W), Mm(PAGE_H), "Layer 1");
|
||||
let font = doc.add_builtin_font(BuiltinFont::Helvetica)?;
|
||||
let font_bold = doc.add_builtin_font(BuiltinFont::HelveticaBold)?;
|
||||
Ok(Self {
|
||||
doc,
|
||||
font,
|
||||
font_bold,
|
||||
page_idx,
|
||||
layer_idx,
|
||||
current_y: HEADER_Y_START,
|
||||
})
|
||||
}
|
||||
|
||||
fn layer(&self) -> PdfLayerReference {
|
||||
self.doc.get_page(self.page_idx).get_layer(self.layer_idx)
|
||||
}
|
||||
|
||||
fn write_text(&self, s: &str, font_size: f32, x: f32, y: f32, bold: bool) {
|
||||
let f = if bold { &self.font_bold } else { &self.font };
|
||||
self.layer().use_text(s, font_size, Mm(x), Mm(y), f);
|
||||
}
|
||||
|
||||
fn new_page(&mut self) {
|
||||
let (pi, li) = self.doc.add_page(Mm(PAGE_W), Mm(PAGE_H), "Layer 1");
|
||||
self.page_idx = pi;
|
||||
self.layer_idx = li;
|
||||
self.current_y = HEADER_Y_START;
|
||||
}
|
||||
|
||||
fn ensure_space(&mut self, needed: f32) {
|
||||
if self.current_y - needed < NEW_PAGE_THRESHOLD {
|
||||
self.new_page();
|
||||
}
|
||||
}
|
||||
|
||||
fn table_row(&mut self, cells: &[&str], col_x: &[f32], font_size: f32, bold: bool) {
|
||||
self.ensure_space(ROW_H);
|
||||
let y = self.current_y;
|
||||
for (i, cell) in cells.iter().enumerate() {
|
||||
let x = col_x.get(i).copied().unwrap_or(MARGIN);
|
||||
let s = if cell.len() > 30 { &cell[..30] } else { cell };
|
||||
self.write_text(s, font_size, x, y, bold);
|
||||
}
|
||||
self.current_y -= ROW_H;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn embed_image(
|
||||
&self,
|
||||
raw_rgb: Vec<u8>,
|
||||
img_w: u32,
|
||||
img_h: u32,
|
||||
x_mm: f32,
|
||||
y_mm: f32,
|
||||
scale_x: f32,
|
||||
scale_y: f32,
|
||||
) -> anyhow::Result<()> {
|
||||
// Validate dimensions and buffer size to prevent panics
|
||||
let expected_len = (img_w as usize) * (img_h as usize) * 3;
|
||||
if raw_rgb.len() != expected_len || img_w == 0 || img_h == 0 {
|
||||
anyhow::bail!(
|
||||
"image buffer size mismatch: expected {} bytes for {}x{} RGB, got {}",
|
||||
expected_len,
|
||||
img_w,
|
||||
img_h,
|
||||
raw_rgb.len()
|
||||
);
|
||||
}
|
||||
if !scale_x.is_finite() || !scale_y.is_finite() || scale_x <= 0.0 || scale_y <= 0.0 {
|
||||
anyhow::bail!(
|
||||
"invalid image scale factors: scale_x={}, scale_y={}",
|
||||
scale_x,
|
||||
scale_y
|
||||
);
|
||||
}
|
||||
let xobj = ImageXObject {
|
||||
width: Px(img_w as usize),
|
||||
height: Px(img_h as usize),
|
||||
color_space: ColorSpace::Rgb,
|
||||
bits_per_component: ColorBits::Bit8,
|
||||
interpolate: true,
|
||||
image_data: raw_rgb,
|
||||
image_filter: None,
|
||||
smask: None,
|
||||
clipping_bbox: None,
|
||||
};
|
||||
let pdf_img = Image::from(xobj);
|
||||
pdf_img.add_to_layer(
|
||||
self.layer(),
|
||||
ImageTransform {
|
||||
translate_x: Some(Mm(x_mm)),
|
||||
translate_y: Some(Mm(y_mm)),
|
||||
scale_x: Some(scale_x),
|
||||
scale_y: Some(scale_y),
|
||||
dpi: Some(150.0),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn save(self) -> anyhow::Result<Vec<u8>> {
|
||||
Ok(self.doc.save_to_bytes()?)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Title page helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn write_title_page(pdf: &mut PdfBuilder, title: &str, params: &ReportParams) {
|
||||
pdf.write_text(title, 24.0, MARGIN, 160.0, true);
|
||||
pdf.write_text(
|
||||
&format!(
|
||||
"Generated: {}",
|
||||
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC")
|
||||
),
|
||||
11.0,
|
||||
MARGIN,
|
||||
148.0,
|
||||
false,
|
||||
);
|
||||
if let Some(from) = params.from {
|
||||
pdf.write_text(
|
||||
&format!("From: {}", from.format("%Y-%m-%d")),
|
||||
10.0,
|
||||
MARGIN,
|
||||
140.0,
|
||||
false,
|
||||
);
|
||||
}
|
||||
if let Some(to) = params.to {
|
||||
pdf.write_text(
|
||||
&format!("To: {}", to.format("%Y-%m-%d")),
|
||||
10.0,
|
||||
MARGIN,
|
||||
134.0,
|
||||
false,
|
||||
);
|
||||
}
|
||||
if let Some(gid) = params.group_id {
|
||||
pdf.write_text(&format!("Group: {}", gid), 10.0, MARGIN, 128.0, false);
|
||||
}
|
||||
pdf.new_page();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compliance PDF
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn compliance_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
use sqlx::Row;
|
||||
let rows = if let Some(gid) = params.group_id {
|
||||
sqlx::query(
|
||||
"
|
||||
SELECT h.display_name, h.fqdn,
|
||||
COALESCE(jsonb_array_length(pd.installed_packages),0) AS total_packages,
|
||||
COALESCE(pd.patch_count,0) AS pending_patches,
|
||||
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages),0)=0 THEN 100.0
|
||||
ELSE ROUND(CAST((1.0-pd.patch_count::float/NULLIF(jsonb_array_length(pd.installed_packages),0))*100 AS numeric),1)
|
||||
END)::float8 AS compliance_pct,
|
||||
h.health_status::text AS health_status
|
||||
FROM hosts h LEFT JOIN host_patch_data pd ON pd.host_id=h.id
|
||||
WHERE h.id IN (SELECT host_id FROM host_groups WHERE group_id=$1)
|
||||
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
|
||||
ORDER BY compliance_pct ASC",
|
||||
)
|
||||
.bind(gid)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("compliance PDF query (group) failed")?
|
||||
} else {
|
||||
sqlx::query(
|
||||
"
|
||||
SELECT h.display_name, h.fqdn,
|
||||
COALESCE(jsonb_array_length(pd.installed_packages),0) AS total_packages,
|
||||
COALESCE(pd.patch_count,0) AS pending_patches,
|
||||
(CASE WHEN COALESCE(jsonb_array_length(pd.installed_packages),0)=0 THEN 100.0
|
||||
ELSE ROUND(CAST((1.0-pd.patch_count::float/NULLIF(jsonb_array_length(pd.installed_packages),0))*100 AS numeric),1)
|
||||
END)::float8 AS compliance_pct,
|
||||
h.health_status::text AS health_status
|
||||
FROM hosts h LEFT JOIN host_patch_data pd ON pd.host_id=h.id
|
||||
GROUP BY h.id, h.health_status, pd.installed_packages, pd.patch_count
|
||||
ORDER BY compliance_pct ASC",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("compliance PDF query failed")?
|
||||
};
|
||||
let labels: Vec<String> = rows
|
||||
.iter()
|
||||
.map(|r| r.try_get::<String, _>("display_name").unwrap_or_default())
|
||||
.collect();
|
||||
let values: Vec<f64> = rows
|
||||
.iter()
|
||||
.map(|r| r.try_get::<f64, _>("compliance_pct").unwrap_or(0.0))
|
||||
.collect();
|
||||
let mut pdf = PdfBuilder::new("Compliance Report")?;
|
||||
write_title_page(&mut pdf, "Compliance Report", params);
|
||||
let col_x: &[f32] = &[MARGIN, 65.0, 130.0, 165.0, 200.0, 235.0];
|
||||
pdf.table_row(
|
||||
&[
|
||||
"Host",
|
||||
"FQDN",
|
||||
"Total Pkgs",
|
||||
"Pending",
|
||||
"Compliance %",
|
||||
"Status",
|
||||
],
|
||||
col_x,
|
||||
9.0,
|
||||
true,
|
||||
);
|
||||
for row in &rows {
|
||||
let name: String = row.try_get("display_name").unwrap_or_default();
|
||||
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
|
||||
let total: i64 = row.try_get("total_packages").unwrap_or(0);
|
||||
let pend: i64 = row.try_get("pending_patches").unwrap_or(0);
|
||||
let pct: f64 = row.try_get("compliance_pct").unwrap_or(0.0);
|
||||
let status: String = row.try_get("health_status").unwrap_or_default();
|
||||
pdf.table_row(
|
||||
&[
|
||||
&name,
|
||||
&fqdn,
|
||||
&total.to_string(),
|
||||
&pend.to_string(),
|
||||
&format!("{:.1}%", pct),
|
||||
&status,
|
||||
],
|
||||
col_x,
|
||||
8.0,
|
||||
false,
|
||||
);
|
||||
}
|
||||
if !labels.is_empty() {
|
||||
match render_bar_chart(&labels, &values, "Compliance % by Host") {
|
||||
Ok((raw, w, h)) => {
|
||||
pdf.new_page();
|
||||
pdf.write_text("Compliance Chart", 16.0, MARGIN, 200.0, true);
|
||||
if let Err(e) = pdf.embed_image(raw, w, h, MARGIN, 10.0, 0.28, 0.28) {
|
||||
tracing::warn!(error = %e, "chart embed failed");
|
||||
}
|
||||
},
|
||||
Err(e) => tracing::warn!(error = %e, "chart render failed"),
|
||||
}
|
||||
}
|
||||
pdf.save()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Patch history PDF
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn patch_history_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
use sqlx::Row;
|
||||
let rows = sqlx::query(
|
||||
"
|
||||
SELECT pj.kind::text AS job_kind, pj.status::text AS job_status,
|
||||
h.display_name, h.fqdn, pjh.started_at, pjh.completed_at,
|
||||
EXTRACT(EPOCH FROM (pjh.completed_at-pjh.started_at))::bigint AS duration_seconds,
|
||||
COALESCE(u.username,'system') AS operator
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs pj ON pj.id=pjh.job_id
|
||||
JOIN hosts h ON h.id=pjh.host_id
|
||||
LEFT JOIN users u ON u.id=pj.created_by_user_id
|
||||
WHERE ($1::timestamptz IS NULL OR pjh.started_at>=$1)
|
||||
AND ($2::timestamptz IS NULL OR pjh.started_at<=$2)
|
||||
ORDER BY pjh.started_at DESC",
|
||||
)
|
||||
.bind(params.from)
|
||||
.bind(params.to)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("patch history PDF query failed")?;
|
||||
let mut dc: std::collections::BTreeMap<String, f64> = std::collections::BTreeMap::new();
|
||||
for row in &rows {
|
||||
if let Ok(Some(s)) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("started_at") {
|
||||
*dc.entry(s.format("%Y-%m-%d").to_string()).or_insert(0.0) += 1.0;
|
||||
}
|
||||
}
|
||||
let cl: Vec<String> = dc.keys().cloned().collect();
|
||||
let cv: Vec<f64> = dc.values().cloned().collect();
|
||||
let mut pdf = PdfBuilder::new("Patch History Report")?;
|
||||
write_title_page(&mut pdf, "Patch History Report", params);
|
||||
let col_x: &[f32] = &[MARGIN, 45.0, 80.0, 115.0, 155.0, 200.0, 245.0, 270.0];
|
||||
pdf.table_row(
|
||||
&[
|
||||
"Kind",
|
||||
"Status",
|
||||
"Host",
|
||||
"FQDN",
|
||||
"Started",
|
||||
"Completed",
|
||||
"Dur(s)",
|
||||
"Operator",
|
||||
],
|
||||
col_x,
|
||||
9.0,
|
||||
true,
|
||||
);
|
||||
for row in &rows {
|
||||
let kind: String = row.try_get("job_kind").unwrap_or_default();
|
||||
let status: String = row.try_get("job_status").unwrap_or_default();
|
||||
let name: String = row.try_get("display_name").unwrap_or_default();
|
||||
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
|
||||
let started: String = row
|
||||
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("started_at")
|
||||
.unwrap_or(None)
|
||||
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
|
||||
.unwrap_or_default();
|
||||
let completed: String = row
|
||||
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("completed_at")
|
||||
.unwrap_or(None)
|
||||
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
|
||||
.unwrap_or_default();
|
||||
let dur: i64 = row.try_get("duration_seconds").unwrap_or(0);
|
||||
let op: String = row.try_get("operator").unwrap_or_default();
|
||||
pdf.table_row(
|
||||
&[
|
||||
&kind,
|
||||
&status,
|
||||
&name,
|
||||
&fqdn,
|
||||
&started,
|
||||
&completed,
|
||||
&dur.to_string(),
|
||||
&op,
|
||||
],
|
||||
col_x,
|
||||
8.0,
|
||||
false,
|
||||
);
|
||||
}
|
||||
if !cl.is_empty() {
|
||||
match render_bar_chart(&cl, &cv, "Jobs per Day") {
|
||||
Ok((raw, w, h)) => {
|
||||
pdf.new_page();
|
||||
pdf.write_text("Patch Activity Chart", 16.0, MARGIN, 200.0, true);
|
||||
if let Err(e) = pdf.embed_image(raw, w, h, MARGIN, 10.0, 0.28, 0.28) {
|
||||
tracing::warn!(error = %e, "chart embed failed");
|
||||
}
|
||||
},
|
||||
Err(e) => tracing::warn!(error = %e, "chart render failed"),
|
||||
}
|
||||
}
|
||||
pdf.save()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Vulnerability PDF
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn vulnerability_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
use sqlx::Row;
|
||||
// Query DB FIRST (before creating any non-Send PdfBuilder)
|
||||
let query_result = sqlx::query("
|
||||
SELECT h.display_name, h.fqdn,
|
||||
cve_id,
|
||||
patch->>'name' AS package_name,
|
||||
patch->>'severity' AS severity,
|
||||
patch->>'available_version' AS available_version,
|
||||
pd.polled_at AS last_seen_at
|
||||
FROM hosts h JOIN host_patch_data pd ON pd.host_id=h.id
|
||||
CROSS JOIN LATERAL jsonb_array_elements(COALESCE(pd.available_patches,'[]'::jsonb)) AS patch
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(COALESCE(patch->'cve_ids','[]'::jsonb)) AS cve_id
|
||||
WHERE ($1::timestamptz IS NULL OR pd.polled_at>=$1)
|
||||
AND ($2::timestamptz IS NULL OR pd.polled_at<=$2)
|
||||
ORDER BY CASE patch->>'severity' WHEN 'critical' THEN 1 WHEN 'high' THEN 2 WHEN 'medium' THEN 3 ELSE 4 END,
|
||||
h.display_name")
|
||||
.bind(params.from).bind(params.to).fetch_all(pool).await;
|
||||
// Now create PdfBuilder (non-Send Rc types) after all awaits
|
||||
let mut pdf = PdfBuilder::new("Vulnerability Report")?;
|
||||
write_title_page(&mut pdf, "Vulnerability Exposure Report", params);
|
||||
let col_x: &[f32] = &[MARGIN, 55.0, 100.0, 130.0, 175.0, 215.0, 255.0];
|
||||
pdf.table_row(
|
||||
&[
|
||||
"Host",
|
||||
"FQDN",
|
||||
"CVE ID",
|
||||
"Package",
|
||||
"Severity",
|
||||
"Fix Version",
|
||||
"Last Seen",
|
||||
],
|
||||
col_x,
|
||||
9.0,
|
||||
true,
|
||||
);
|
||||
match query_result {
|
||||
Ok(rows) => {
|
||||
for row in &rows {
|
||||
let name: String = row.try_get("display_name").unwrap_or_default();
|
||||
let fqdn: String = row.try_get("fqdn").unwrap_or_default();
|
||||
let cve: String = row.try_get("cve_id").unwrap_or_default();
|
||||
let pkg: String = row.try_get("package_name").unwrap_or_default();
|
||||
let sev: String = row.try_get("severity").unwrap_or_default();
|
||||
let fix: String = row.try_get("available_version").unwrap_or_default();
|
||||
let seen: String = row
|
||||
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_seen_at")
|
||||
.unwrap_or(None)
|
||||
.map(|d| d.format("%Y-%m-%d").to_string())
|
||||
.unwrap_or_default();
|
||||
pdf.table_row(
|
||||
&[&name, &fqdn, &cve, &pkg, &sev, &fix, &seen],
|
||||
col_x,
|
||||
8.0,
|
||||
false,
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "vulnerability PDF query failed");
|
||||
let y = pdf.current_y;
|
||||
pdf.write_text(&format!("No data: {}", e), 10.0, MARGIN, y, false);
|
||||
},
|
||||
}
|
||||
pdf.save()
|
||||
}
|
||||
|
||||
async fn audit_pdf(pool: &sqlx::PgPool, params: &ReportParams) -> anyhow::Result<Vec<u8>> {
|
||||
use sqlx::Row;
|
||||
let rows = sqlx::query(
|
||||
"
|
||||
SELECT id::text AS id, created_at, action::text AS action,
|
||||
actor_username, target_type, target_id,
|
||||
ip_address::text AS ip_address, request_id
|
||||
FROM audit_log
|
||||
WHERE ($1::timestamptz IS NULL OR created_at>=$1)
|
||||
AND ($2::timestamptz IS NULL OR created_at<=$2)
|
||||
ORDER BY created_at DESC LIMIT 10000",
|
||||
)
|
||||
.bind(params.from)
|
||||
.bind(params.to)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("audit PDF query failed")?;
|
||||
let mut pdf = PdfBuilder::new("Audit Trail Report")?;
|
||||
write_title_page(&mut pdf, "Audit Trail Report", params);
|
||||
let col_x: &[f32] = &[MARGIN, 50.0, 95.0, 135.0, 175.0, 215.0, 255.0];
|
||||
pdf.table_row(
|
||||
&[
|
||||
"Timestamp",
|
||||
"Action",
|
||||
"Actor",
|
||||
"Target Type",
|
||||
"Target ID",
|
||||
"IP",
|
||||
"Request ID",
|
||||
],
|
||||
col_x,
|
||||
9.0,
|
||||
true,
|
||||
);
|
||||
for row in &rows {
|
||||
let created: String = row
|
||||
.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>("created_at")
|
||||
.unwrap_or(None)
|
||||
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
|
||||
.unwrap_or_default();
|
||||
let action: String = row.try_get("action").unwrap_or_default();
|
||||
let actor: String = row.try_get("actor_username").unwrap_or_default();
|
||||
let ttype: String = row.try_get("target_type").unwrap_or_default();
|
||||
let tid: String = row.try_get("target_id").unwrap_or_default();
|
||||
let ip: String = row.try_get("ip_address").unwrap_or_default();
|
||||
let req: String = row.try_get("request_id").unwrap_or_default();
|
||||
pdf.table_row(
|
||||
&[&created, &action, &actor, &ttype, &tid, &ip, &req],
|
||||
col_x,
|
||||
8.0,
|
||||
false,
|
||||
);
|
||||
}
|
||||
pdf.save()
|
||||
}
|
||||
46
crates/pm-web/Cargo.toml
Normal file
46
crates/pm-web/Cargo.toml
Normal file
@ -0,0 +1,46 @@
|
||||
[package]
|
||||
name = "pm-web"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "pm-web"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
pm-ca = { path = "../pm-ca" }
|
||||
pm-core = { path = "../pm-core" }
|
||||
pm-auth = { path = "../pm-auth" }
|
||||
pm-reports = { path = "../pm-reports" }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
axum-server = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
ulid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
dashmap = { version = "6" }
|
||||
tower_governor = { workspace = true }
|
||||
governor = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
|
||||
rand = { workspace = true }
|
||||
hex = "0.4"
|
||||
base64 = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
url = { workspace = true }
|
||||
urlencoding = "2"
|
||||
353
crates/pm-web/src/main.rs
Normal file
353
crates/pm-web/src/main.rs
Normal file
@ -0,0 +1,353 @@
|
||||
//! pm-web — Linux Patch Manager web server.
|
||||
|
||||
mod routes;
|
||||
|
||||
use axum::{extract::State, http::StatusCode, middleware, response::Json, routing::get, Router};
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use dashmap::DashMap;
|
||||
use pm_auth::{
|
||||
jwt,
|
||||
rbac::{require_auth, AuthConfig},
|
||||
};
|
||||
use pm_core::{
|
||||
config::AppConfig, db, logging, models::PkiBundle, request_id::request_id_middleware,
|
||||
};
|
||||
use routes::sso::{OidcCache, SsoSession};
|
||||
use routes::ws::WsTicket;
|
||||
use serde_json::{json, Value};
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_governor::{
|
||||
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
|
||||
};
|
||||
use tower_http::{
|
||||
services::{ServeDir, ServeFile},
|
||||
trace::TraceLayer,
|
||||
};
|
||||
|
||||
/// Shared application state threaded through Axum.
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub db: sqlx::PgPool,
|
||||
pub config: Arc<AppConfig>,
|
||||
pub signing_key_pem: String,
|
||||
pub auth_config: Arc<AuthConfig>,
|
||||
/// In-memory store for single-use WebSocket authentication tickets.
|
||||
pub ws_tickets: Arc<DashMap<String, WsTicket>>,
|
||||
/// In-memory store for SSO PKCE sessions (state → code_verifier).
|
||||
pub sso_sessions: Arc<DashMap<String, SsoSession>>,
|
||||
/// Cached OIDC discovery document and JWKS for SSO id_token verification.
|
||||
pub oidc_cache: Arc<Mutex<OidcCache>>,
|
||||
/// Internal certificate authority for mTLS client cert issuance.
|
||||
pub ca: Arc<pm_ca::CertAuthority>,
|
||||
/// Short-lived cache for approved enrollment PKI bundles.
|
||||
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Install the default crypto provider for rustls (required since 0.23)
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install rustls crypto provider");
|
||||
|
||||
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
|
||||
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
|
||||
|
||||
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
|
||||
eprintln!("Config file not found or invalid, using defaults");
|
||||
AppConfig::default()
|
||||
});
|
||||
|
||||
logging::init(&config.logging);
|
||||
tracing::info!(
|
||||
version = env!("CARGO_PKG_VERSION"),
|
||||
"patch-manager-web starting"
|
||||
);
|
||||
|
||||
let signing_key_pem = jwt::load_signing_key(&config.security.jwt_signing_key_path)
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "JWT signing key not found (dev mode)");
|
||||
String::new()
|
||||
});
|
||||
|
||||
let verify_key_pem =
|
||||
jwt::load_verify_key(&config.security.jwt_verify_key_path).unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "JWT verify key not found (dev mode)");
|
||||
String::new()
|
||||
});
|
||||
|
||||
let auth_config = Arc::new(AuthConfig::new(
|
||||
verify_key_pem,
|
||||
&config.security.ip_whitelist,
|
||||
));
|
||||
|
||||
let pool = db::init_pool(&config.database).await?;
|
||||
db::run_migrations(&pool).await?;
|
||||
|
||||
// Initialise the internal CA using the configured certificate paths.
|
||||
// The CA certificate and key must exist at the configured locations and be
|
||||
// unencrypted PEM. If absent, a new CA is generated in that directory.
|
||||
let ca_base = std::path::Path::new(&config.security.ca_cert_path)
|
||||
.parent()
|
||||
.expect("CA certificate path must have a parent directory");
|
||||
let ca = pm_ca::CertAuthority::init(ca_base, &pool)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(error = %e, "CA init failed (dev mode)");
|
||||
panic!("CA initialization failed: {}", e);
|
||||
});
|
||||
|
||||
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
|
||||
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
|
||||
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
|
||||
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
|
||||
|
||||
// Background task: purge expired WS tickets every 30 seconds.
|
||||
{
|
||||
let tickets = ws_tickets.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(30));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = chrono::Utc::now();
|
||||
let before = tickets.len();
|
||||
tickets.retain(|_, v| v.expires_at > now);
|
||||
let removed = before.saturating_sub(tickets.len());
|
||||
if removed > 0 {
|
||||
tracing::debug!(removed, "Purged expired WS tickets");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Background task: purge expired SSO sessions every 60 seconds (sessions older than 10 minutes).
|
||||
{
|
||||
let sessions = sso_sessions.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = chrono::Utc::now();
|
||||
let cutoff = now - chrono::Duration::minutes(10);
|
||||
let before = sessions.len();
|
||||
sessions.retain(|_, v| v.created_at > cutoff);
|
||||
let removed = before.saturating_sub(sessions.len());
|
||||
if removed > 0 {
|
||||
tracing::debug!(removed, "Purged expired SSO sessions");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Background task: purge approved enrollment PKI bundles every 10 minutes.
|
||||
{
|
||||
let approved = approved_enrollments.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(600));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
approved.clear();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let state = AppState {
|
||||
db: pool,
|
||||
config: Arc::new(config.clone()),
|
||||
signing_key_pem,
|
||||
auth_config,
|
||||
ws_tickets,
|
||||
sso_sessions,
|
||||
ca: Arc::new(ca),
|
||||
approved_enrollments,
|
||||
oidc_cache,
|
||||
};
|
||||
|
||||
let app = build_router(state);
|
||||
|
||||
let addr: SocketAddr = format!("{}:{}", config.server.host, config.server.port)
|
||||
.parse()
|
||||
.expect("Invalid bind address");
|
||||
|
||||
// Try to load TLS certificate and key; fall back to plain HTTP if missing.
|
||||
let tls_cert = std::path::Path::new(&config.security.web_tls_cert_path);
|
||||
let tls_key = std::path::Path::new(&config.security.web_tls_key_path);
|
||||
|
||||
if tls_cert.exists() && tls_key.exists() {
|
||||
let tls_config = RustlsConfig::from_pem_file(
|
||||
&config.security.web_tls_cert_path,
|
||||
&config.security.web_tls_key_path,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load TLS certificates");
|
||||
e
|
||||
})?;
|
||||
|
||||
tracing::info!(%addr, "Listening (HTTPS)");
|
||||
axum_server::bind_rustls(addr, tls_config)
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await?;
|
||||
} else {
|
||||
tracing::warn!(
|
||||
cert_path = %config.security.web_tls_cert_path,
|
||||
key_path = %config.security.web_tls_key_path,
|
||||
"TLS certificates not found — falling back to plain HTTP. \
|
||||
This is insecure for production!"
|
||||
);
|
||||
tracing::info!(%addr, "Listening (HTTP — no TLS)");
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Construct the full Axum router.
|
||||
pub fn build_router(state: AppState) -> Router {
|
||||
let static_dir = state.config.server.static_dir.clone();
|
||||
let auth_config = state.auth_config.clone();
|
||||
let rl = &state.config.rate_limit;
|
||||
|
||||
// Enrollment rate limiting: strict (5 req/min per IP, burst 3)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 12_000ms = ~5/min sustained
|
||||
let enrollment_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(12_000)
|
||||
.burst_size(rl.enrollment_burst)
|
||||
.finish()
|
||||
.expect("Invalid enrollment governor config"),
|
||||
);
|
||||
|
||||
// Auth rate limiting: moderate (20 req/min per IP, burst 10)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 3_000ms = ~20/min sustained
|
||||
let auth_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(3_000)
|
||||
.burst_size(rl.auth_burst)
|
||||
.finish()
|
||||
.expect("Invalid auth governor config"),
|
||||
);
|
||||
|
||||
// API rate limiting: normal (120 req/min per IP, burst 30)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 500ms = ~120/min sustained
|
||||
let api_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(500)
|
||||
.burst_size(rl.api_burst)
|
||||
.finish()
|
||||
.expect("Invalid API governor config"),
|
||||
);
|
||||
|
||||
// Enrollment routes with strict per-IP rate limiting
|
||||
let enrollment_router =
|
||||
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
|
||||
|
||||
// Public auth routes with moderate per-IP rate limiting
|
||||
let auth_public_router =
|
||||
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
|
||||
|
||||
// SSO routes with moderate per-IP rate limiting
|
||||
let sso_public_router =
|
||||
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
|
||||
let sso_azure_router =
|
||||
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
|
||||
|
||||
// All protected API routes — require valid JWT, with normal per-IP rate limiting
|
||||
let protected_api = Router::new()
|
||||
// Auth: MFA setup/verify
|
||||
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
|
||||
.nest("/auth", routes::auth::protected_router())
|
||||
// Hosts
|
||||
.nest("/hosts", routes::hosts::router())
|
||||
// Host-scoped certificate endpoints (merged separately to avoid conflict)
|
||||
.nest("/hosts", routes::ca::host_cert_router())
|
||||
// Groups
|
||||
.nest("/groups", routes::groups::router())
|
||||
// Users
|
||||
.nest("/users", routes::users::router())
|
||||
// Discovery
|
||||
.nest("/discovery", routes::discovery::router())
|
||||
// Fleet status
|
||||
.nest("/status", routes::status::router())
|
||||
// Patch jobs
|
||||
.nest("/jobs", routes::jobs::router())
|
||||
// Maintenance windows (nested under hosts path param)
|
||||
.nest(
|
||||
"/hosts/{host_id}/maintenance-windows",
|
||||
routes::maintenance_windows::router(),
|
||||
)
|
||||
// Maintenance windows — bulk list-all endpoint
|
||||
.nest(
|
||||
"/maintenance-windows",
|
||||
routes::maintenance_windows::all_windows_router(),
|
||||
)
|
||||
// CA root certificate download
|
||||
.nest("/ca", routes::ca::ca_router())
|
||||
// Certificate list / renew / revoke
|
||||
.nest("/certificates", routes::ca::certs_router())
|
||||
// WS ticket issuance (JWT-protected — ticket returned to browser, then used for WS upgrade)
|
||||
.merge(routes::ws::ticket_router())
|
||||
// Reports
|
||||
.nest("/reports", routes::reports::router())
|
||||
.nest(
|
||||
"/hosts/{host_id}/health-checks",
|
||||
routes::health_checks::router(),
|
||||
)
|
||||
// Settings (admin-only)
|
||||
.nest("/settings", routes::settings::router())
|
||||
// Admin enrollment routes (JWT protected, Admin role enforced)
|
||||
.nest("/admin", routes::enrollment::admin_router())
|
||||
// Apply rate limiting then auth middleware
|
||||
.layer(GovernorLayer::new(api_governor))
|
||||
.route_layer(middleware::from_fn(move |req, next| {
|
||||
let auth_config = auth_config.clone();
|
||||
require_auth(auth_config, req, next)
|
||||
}));
|
||||
|
||||
Router::new()
|
||||
.route("/status/health", get(health_handler))
|
||||
// Public auth routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth", auth_public_router)
|
||||
// Public enrollment endpoints (rate-limited, no JWT)
|
||||
.nest("/api/v1", enrollment_router)
|
||||
// Public SSO routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth/sso", sso_public_router)
|
||||
// Public Azure SSO routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth/azure", sso_azure_router)
|
||||
// Protected API routes (JWT required, rate-limited)
|
||||
.nest("/api/v1", protected_api)
|
||||
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
|
||||
.merge(routes::ws::ws_router())
|
||||
// Serve React SPA
|
||||
.fallback_service(
|
||||
ServeDir::new(&static_dir)
|
||||
.append_index_html_on_directories(true)
|
||||
.fallback(ServeFile::new(format!("{}/index.html", static_dir))),
|
||||
)
|
||||
.layer(middleware::from_fn(request_id_middleware))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn health_handler(State(state): State<AppState>) -> Result<Json<Value>, StatusCode> {
|
||||
let db_ok = sqlx::query("SELECT 1").execute(&state.db).await.is_ok();
|
||||
let status = if db_ok { "healthy" } else { "degraded" };
|
||||
let body = json!({ "service": "patch-manager-web", "version": env!("CARGO_PKG_VERSION"), "status": status, "database": if db_ok { "ok" } else { "error" } });
|
||||
if db_ok {
|
||||
Ok(Json(body))
|
||||
} else {
|
||||
Err(StatusCode::SERVICE_UNAVAILABLE)
|
||||
}
|
||||
}
|
||||
434
crates/pm-web/src/routes/auth.rs
Executable file
434
crates/pm-web/src/routes/auth.rs
Executable file
@ -0,0 +1,434 @@
|
||||
//! Authentication route handlers.
|
||||
//!
|
||||
//! Public routes (no auth required):
|
||||
//! POST /api/v1/auth/login
|
||||
//! POST /api/v1/auth/refresh
|
||||
//! POST /api/v1/auth/logout
|
||||
//!
|
||||
//! Protected routes (JWT required):
|
||||
//! GET /api/v1/auth/mfa/setup
|
||||
//! POST /api/v1/auth/mfa/verify
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::Json,
|
||||
routing::delete,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::{
|
||||
hash_password, mfa_totp,
|
||||
rbac::AuthUser,
|
||||
session::{self, LoginRequest, LoginResponse},
|
||||
validate_password_strength, verify_password,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ============================================================
|
||||
// Public router — no authentication required
|
||||
// ============================================================
|
||||
|
||||
pub fn public_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", post(login_handler))
|
||||
.route("/refresh", post(refresh_handler))
|
||||
.route("/logout", post(logout_handler))
|
||||
.route(
|
||||
"/force-change-password",
|
||||
post(force_change_password_handler),
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Protected router — requires valid JWT (applied by caller)
|
||||
// ============================================================
|
||||
|
||||
pub fn protected_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/mfa/setup", get(mfa_setup_handler))
|
||||
.route("/mfa/verify", post(mfa_verify_handler))
|
||||
.route("/mfa", delete(disable_mfa))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
fn user_agent(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("user-agent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
fn remote_ip(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.split(',').next().unwrap_or("").trim().to_string())
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/login
|
||||
// ============================================================
|
||||
|
||||
async fn login_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
|
||||
let ip = remote_ip(&headers);
|
||||
let ua = user_agent(&headers);
|
||||
|
||||
session::login(
|
||||
&state.db,
|
||||
&req,
|
||||
&state.signing_key_pem,
|
||||
state.config.security.jwt_access_ttl_secs as i64,
|
||||
ua.as_deref(),
|
||||
ip.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
use pm_auth::session::SessionError;
|
||||
let (status, code, message) = match e {
|
||||
SessionError::InvalidCredentials | SessionError::InvalidMfaCode => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_credentials",
|
||||
"Invalid username or password",
|
||||
),
|
||||
SessionError::MfaRequired => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"mfa_required",
|
||||
"MFA code required",
|
||||
),
|
||||
SessionError::AccountDisabled => (
|
||||
StatusCode::FORBIDDEN,
|
||||
"account_disabled",
|
||||
"Account is disabled",
|
||||
),
|
||||
SessionError::PasswordResetRequired => (
|
||||
StatusCode::FORBIDDEN,
|
||||
"password_reset_required",
|
||||
"Password reset is required before login",
|
||||
),
|
||||
SessionError::AccountLocked => (
|
||||
StatusCode::LOCKED,
|
||||
"account_locked",
|
||||
"Account is locked due to too many failed login attempts",
|
||||
),
|
||||
_ => {
|
||||
tracing::error!(error = %e, "Login error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"An error occurred",
|
||||
)
|
||||
},
|
||||
};
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/refresh
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RefreshRequest {
|
||||
refresh_token: String,
|
||||
}
|
||||
|
||||
async fn refresh_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<RefreshRequest>,
|
||||
) -> Result<Json<LoginResponse>, (StatusCode, Json<Value>)> {
|
||||
let ip = remote_ip(&headers);
|
||||
let ua = user_agent(&headers);
|
||||
|
||||
session::refresh_session(
|
||||
&state.db,
|
||||
&req.refresh_token,
|
||||
&state.signing_key_pem,
|
||||
state.config.security.jwt_access_ttl_secs as i64,
|
||||
ua.as_deref(),
|
||||
ip.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
use pm_auth::session::SessionError;
|
||||
let (status, code, msg) = match e {
|
||||
SessionError::Refresh(_) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_refresh_token",
|
||||
"Refresh token is invalid or expired",
|
||||
),
|
||||
SessionError::AccountDisabled => (
|
||||
StatusCode::FORBIDDEN,
|
||||
"account_disabled",
|
||||
"Account is disabled",
|
||||
),
|
||||
_ => {
|
||||
tracing::error!(error = %e, "Refresh error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"An error occurred",
|
||||
)
|
||||
},
|
||||
};
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": msg } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/logout
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LogoutRequest {
|
||||
refresh_token: String,
|
||||
}
|
||||
|
||||
async fn logout_handler(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<LogoutRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
session::logout(&state.db, &req.refresh_token)
|
||||
.await
|
||||
.map(|_| Json(json!({ "message": "Logged out successfully" })))
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Logout error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "An error occurred" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/mfa/setup (JWT required — via middleware)
|
||||
// ============================================================
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/force-change-password (PUBLIC — no JWT)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ForceChangePasswordRequest {
|
||||
username: String,
|
||||
current_password: String,
|
||||
new_password: String,
|
||||
}
|
||||
|
||||
async fn force_change_password_handler(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<ForceChangePasswordRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Validate new password strength
|
||||
if let Err(msg) = validate_password_strength(&req.new_password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Look up user by username
|
||||
let row: Option<(Uuid, Option<String>, bool)> = sqlx::query_as(
|
||||
"SELECT id, password_hash, force_password_reset FROM users WHERE username = $1 AND auth_provider = 'local'",
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to fetch user");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (user_id, hash_opt, _force_reset) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
|
||||
),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Verify current password
|
||||
let hash_str = hash_opt.as_deref().unwrap_or("");
|
||||
let valid = verify_password(&req.current_password, hash_str).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_credentials", "message": "Invalid username or password" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Hash and update password, clear force_password_reset
|
||||
let new_hash = hash_password(&req.new_password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, failed_login_attempts = 0, locked_until = NULL, updated_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(&new_hash)
|
||||
.bind(user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update password");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(user_id = %user_id, username = %req.username, "Password changed via force-change-password");
|
||||
|
||||
Ok(Json(json!({ "message": "Password changed successfully" })))
|
||||
}
|
||||
|
||||
async fn mfa_setup_handler(
|
||||
auth_user: AuthUser,
|
||||
) -> Result<Json<mfa_totp::TotpSetup>, (StatusCode, Json<Value>)> {
|
||||
mfa_totp::generate_setup(&auth_user.username)
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "TOTP setup error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/auth/mfa/verify (JWT required — via middleware)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MfaVerifyRequest {
|
||||
secret_base32: String,
|
||||
code: String,
|
||||
}
|
||||
|
||||
async fn mfa_verify_handler(
|
||||
State(state): State<AppState>,
|
||||
auth_user: AuthUser,
|
||||
Json(req): Json<MfaVerifyRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let valid =
|
||||
mfa_totp::verify_code(&auth_user.username, &req.secret_base32, &req.code).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "invalid_code", "message": "Invalid TOTP code" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("UPDATE users SET totp_secret = $1, mfa_enabled = TRUE WHERE id = $2")
|
||||
.bind(&req.secret_base32)
|
||||
.bind(auth_user.user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to save TOTP secret");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to enable MFA" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(user_id = %auth_user.user_id, "MFA enabled for user");
|
||||
Ok(Json(json!({ "message": "MFA enabled successfully" })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// DELETE /api/v1/auth/mfa (JWT required — disable own MFA)
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DisableMfaRequest {
|
||||
password: String,
|
||||
}
|
||||
|
||||
async fn disable_mfa(
|
||||
State(state): State<AppState>,
|
||||
auth_user: AuthUser,
|
||||
Json(req): Json<DisableMfaRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Verify current password to confirm identity
|
||||
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
|
||||
.bind(auth_user.user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to fetch password hash");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?
|
||||
.flatten();
|
||||
|
||||
let hash_str = hash.unwrap_or_default();
|
||||
let valid = verify_password(&req.password, &hash_str).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE WHERE id = $1")
|
||||
.bind(auth_user.user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to disable MFA");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(user_id = %auth_user.user_id, "MFA disabled for user");
|
||||
Ok(Json(json!({ "message": "MFA disabled successfully" })))
|
||||
}
|
||||
516
crates/pm-web/src/routes/ca.rs
Executable file
516
crates/pm-web/src/routes/ca.rs
Executable file
@ -0,0 +1,516 @@
|
||||
//! CA / certificate management routes.
|
||||
//!
|
||||
//! ca_router() → mounted at /api/v1/ca
|
||||
//! GET /root.crt download_root_ca (any authed role)
|
||||
//!
|
||||
//! certs_router() → mounted at /api/v1/certificates
|
||||
//! GET / list_certificates (any authed role)
|
||||
//! POST /:cert_id/renew renew_cert (admin only)
|
||||
//! DELETE /:cert_id revoke_cert (admin only)
|
||||
//!
|
||||
//! host_cert_router() → merged under /api/v1/hosts
|
||||
//! GET /:host_id/client.crt download_client_cert (admin only)
|
||||
//! POST /:host_id/certificates issue_client_cert (admin only)
|
||||
//! POST /:host_id/certificates/reissue reissue_host_cert (admin only)
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, Query, State},
|
||||
http::{header, Response, StatusCode},
|
||||
response::Json,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::audit::{log_event, AuditAction};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::Row;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router constructors ───────────────────────────────────────────────────────
|
||||
|
||||
/// Handles routes mounted at /api/v1/ca
|
||||
pub fn ca_router() -> Router<AppState> {
|
||||
Router::new().route("/root.crt", get(download_root_ca))
|
||||
}
|
||||
|
||||
/// Handles routes mounted at /api/v1/certificates
|
||||
pub fn certs_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_certificates))
|
||||
.route("/{cert_id}/renew", post(renew_cert))
|
||||
.route("/{cert_id}", delete(revoke_cert))
|
||||
}
|
||||
|
||||
/// Handles cert-specific paths merged under /api/v1/hosts.
|
||||
/// Only adds paths not already claimed by the hosts router.
|
||||
pub fn host_cert_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/{host_id}/client.crt", get(download_client_cert))
|
||||
.route("/{host_id}/certificates", post(issue_client_cert))
|
||||
.route("/{host_id}/certificates/reissue", post(reissue_host_cert))
|
||||
}
|
||||
|
||||
// ── Shared types ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Row returned from the `certificates` table.
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
struct CertRow {
|
||||
id: Uuid,
|
||||
host_id: Option<Uuid>,
|
||||
serial_number: String,
|
||||
common_name: String,
|
||||
/// Cast to TEXT in all queries to avoid custom-enum decode.
|
||||
status: String,
|
||||
issued_at: DateTime<Utc>,
|
||||
expires_at: DateTime<Utc>,
|
||||
revoked_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Query params for `list_certificates`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CertListQuery {
|
||||
host_id: Option<Uuid>,
|
||||
status: Option<String>,
|
||||
}
|
||||
|
||||
/// Request body for `issue_client_cert`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IssueCertRequest {
|
||||
hostname: String,
|
||||
}
|
||||
|
||||
// ── Helper: build PEM download response ──────────────────────────────────────
|
||||
|
||||
fn pem_response(pem: String, filename: &str) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
|
||||
let disposition = format!("attachment; filename=\"{filename}\"");
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/x-pem-file")
|
||||
.header(header::CONTENT_DISPOSITION, disposition)
|
||||
.body(Body::from(pem))
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build PEM response");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Response build error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── Helper: admin-only guard ──────────────────────────────────────────────────
|
||||
|
||||
fn require_write_access(user: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
|
||||
if !user.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Helper: map sqlx error to 500 ─────────────────────────────────────────────
|
||||
|
||||
fn db_error(e: sqlx::Error) -> (StatusCode, Json<Value>) {
|
||||
tracing::error!(error = %e, "Database error");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── Helper: build the full IssuedCert JSON response ──────────────────────────
|
||||
|
||||
fn issued_cert_json(issued: &pm_ca::IssuedCert) -> Value {
|
||||
json!({
|
||||
"cert_pem": issued.cert_pem,
|
||||
"key_pem": issued.key_pem,
|
||||
"serial_number": issued.serial_number,
|
||||
"expires_at": issued.expires_at,
|
||||
"server_cert_pem": issued.server_cert_pem,
|
||||
"server_key_pem": issued.server_key_pem,
|
||||
"server_serial_number": issued.server_serial_number,
|
||||
"ca_root_pem": issued.ca_root_pem,
|
||||
})
|
||||
}
|
||||
|
||||
// ── GET /api/v1/ca/root.crt ───────────────────────────────────────────────────
|
||||
|
||||
/// Download the root CA certificate as a PEM file.
|
||||
async fn download_root_ca(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
|
||||
let pem = state.ca.root_cert_pem().to_owned();
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateDownloaded,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some("root_ca"),
|
||||
json!({ "operation": "download_root_ca" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
pem_response(pem, "ca.crt")
|
||||
}
|
||||
|
||||
// ── GET /api/v1/certificates ──────────────────────────────────────────────────
|
||||
|
||||
/// List certificates with optional `?host_id=` and `?status=` filters.
|
||||
async fn list_certificates(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Query(q): Query<CertListQuery>,
|
||||
) -> Result<Json<Vec<CertRow>>, (StatusCode, Json<Value>)> {
|
||||
// Use the non-macro query_as form — avoids needing DATABASE_URL at compile
|
||||
// time. status is cast to TEXT so sqlx decodes it into String directly.
|
||||
let rows: Vec<CertRow> = match (q.host_id, q.status.as_deref()) {
|
||||
(Some(hid), Some(st)) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
WHERE host_id = $1 AND status::text = $2
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.bind(hid)
|
||||
.bind(st)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
(Some(hid), None) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
WHERE host_id = $1
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.bind(hid)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
(None, Some(st)) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
WHERE status::text = $1
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.bind(st)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
(None, None) => {
|
||||
sqlx::query_as::<_, CertRow>(
|
||||
r#"SELECT id, host_id, serial_number, common_name,
|
||||
status::text AS status,
|
||||
issued_at, expires_at, revoked_at
|
||||
FROM certificates
|
||||
ORDER BY issued_at DESC"#,
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
},
|
||||
}
|
||||
.map_err(db_error)?;
|
||||
|
||||
Ok(Json(rows))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:host_id/client.crt ────────────────────────────────────
|
||||
|
||||
/// Download the most recent active client certificate PEM for a host.
|
||||
async fn download_client_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Response<Body>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
let cert_pem: Option<String> = sqlx::query_scalar(
|
||||
r#"SELECT cert_pem
|
||||
FROM certificates
|
||||
WHERE host_id = $1
|
||||
AND status = 'active'::cert_status
|
||||
AND common_name NOT LIKE '%-server'
|
||||
ORDER BY issued_at DESC
|
||||
LIMIT 1"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to fetch client cert");
|
||||
db_error(e)
|
||||
})?;
|
||||
|
||||
match cert_pem {
|
||||
Some(pem) => {
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateDownloaded,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "operation": "download_client_cert" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
pem_response(pem, "client.crt")
|
||||
},
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({
|
||||
"error": {
|
||||
"code": "not_found",
|
||||
"message": "No active certificate found for this host"
|
||||
}
|
||||
})),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/certificates ─────────────────────────────────
|
||||
|
||||
/// Issue a new mTLS client certificate (and server certificate) for a host.
|
||||
/// **The private keys are returned only once — the caller must save them.**
|
||||
async fn issue_client_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
Json(req): Json<IssueCertRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
// Look up the host's IP address from the database.
|
||||
let ip_address: String = sqlx::query_scalar("SELECT host(ip_address) FROM hosts WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to fetch host IP address");
|
||||
if e.to_string().contains("no rows") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
} else {
|
||||
db_error(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let issued = state
|
||||
.ca
|
||||
.issue_client_cert(host_id, &req.hostname, &ip_address, &state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, hostname = %req.hostname,
|
||||
"Failed to issue client cert");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateIssued,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "hostname": req.hostname, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(issued_cert_json(&issued)))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/certificates/:cert_id/renew ─────────────────────────────────
|
||||
|
||||
/// Revoke the specified certificate and issue a replacement with the same CN.
|
||||
async fn renew_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(cert_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
let issued = state.ca.renew_cert(cert_id, &state.db).await.map_err(|e| {
|
||||
let msg = e.to_string();
|
||||
tracing::error!(error = %e, %cert_id, "Failed to renew cert");
|
||||
if msg.contains("not found") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(
|
||||
json!({ "error": { "code": "not_found", "message": "Certificate not found" } }),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
|
||||
)
|
||||
}
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateRenewed,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&cert_id.to_string()),
|
||||
json!({ "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(issued_cert_json(&issued)))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/certificates/reissue ────────────────────────
|
||||
|
||||
/// Revoke ALL active certificates for a host and issue new ones.
|
||||
/// The private keys are returned only once — the caller must save them.
|
||||
async fn reissue_host_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
// Look up the host's FQDN and IP address for the new certificate CN and SANs.
|
||||
let row = sqlx::query("SELECT fqdn, host(ip_address) AS ip_address FROM hosts WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to fetch host FQDN/IP");
|
||||
if e.to_string().contains("no rows") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
} else {
|
||||
db_error(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let fqdn: String = row.try_get("fqdn").map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to read fqdn");
|
||||
db_error(e)
|
||||
})?;
|
||||
let ip_address: String = row.try_get("ip_address").map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to read ip_address");
|
||||
db_error(e)
|
||||
})?;
|
||||
|
||||
// Revoke all active certificates for this host.
|
||||
let revoked = sqlx::query(
|
||||
"UPDATE certificates SET status = 'revoked'::cert_status, revoked_at = NOW() \
|
||||
WHERE host_id = $1 AND status = 'active'::cert_status",
|
||||
)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(db_error)?;
|
||||
|
||||
tracing::info!(%host_id, rows_revoked = revoked.rows_affected(), "Revoked all active certs for host");
|
||||
|
||||
// Issue a new certificate bundle using the host's FQDN and IP.
|
||||
let issued = state
|
||||
.ca
|
||||
.issue_client_cert(host_id, &fqdn, &ip_address, &state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "Failed to issue new cert during reissue");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateReissued,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "hostname": &fqdn, "serial_number": issued.serial_number, "server_serial_number": issued.server_serial_number, "rows_revoked": revoked.rows_affected() }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(issued_cert_json(&issued)))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/certificates/:cert_id ─────────────────────────────────────
|
||||
|
||||
/// Revoke a certificate by ID. Sets status to 'revoked' in the database.
|
||||
async fn revoke_cert(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(cert_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
require_write_access(&auth)?;
|
||||
|
||||
state
|
||||
.ca
|
||||
.revoke_cert(cert_id, &state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = e.to_string();
|
||||
tracing::error!(error = %e, %cert_id, "Failed to revoke cert");
|
||||
if msg.contains("not found") {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Certificate not found" } })),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": msg } })),
|
||||
)
|
||||
}
|
||||
})?;
|
||||
|
||||
tracing::info!(%cert_id, "Certificate revoked via API");
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::CertificateRevoked,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("certificate"),
|
||||
Some(&cert_id.to_string()),
|
||||
json!({ "operation": "revoke" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "revoked": true })))
|
||||
}
|
||||
304
crates/pm-web/src/routes/discovery.rs
Executable file
304
crates/pm-web/src/routes/discovery.rs
Executable file
@ -0,0 +1,304 @@
|
||||
//! CIDR auto-discovery routes.
|
||||
//!
|
||||
//! POST /api/v1/discovery/cidr — start a CIDR scan
|
||||
//! GET /api/v1/discovery/:scan_id — get scan results
|
||||
//! POST /api/v1/discovery/:id/register — register a discovered host
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{DiscoveryCidrRequest, DiscoveryResult, RegisterDiscoveredRequest},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use std::{
|
||||
net::{IpAddr, TcpStream},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{sync::Semaphore, task};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
/// Maximum concurrent TCP probes during CIDR scan.
|
||||
const MAX_CONCURRENT_PROBES: usize = 128;
|
||||
/// TCP connect timeout per probe.
|
||||
const PROBE_TIMEOUT_SECS: u64 = 2;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/cidr", post(start_cidr_scan))
|
||||
.route("/{scan_id}", get(get_scan_results))
|
||||
.route("/{id}/register", post(register_discovered_host))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/discovery/cidr ───────────────────────────────────────────────
|
||||
|
||||
async fn start_cidr_scan(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<DiscoveryCidrRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let cidr: ipnet::IpNet = req.cidr.parse().map_err(|_| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "bad_request", "message": "Invalid CIDR range" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let agent_port = req.agent_port.unwrap_or(12443) as u16;
|
||||
let scan_id = Uuid::new_v4();
|
||||
|
||||
// Clear previous results for this type of scan and start async scan
|
||||
let pool = state.db.clone();
|
||||
let scan_id_clone = scan_id;
|
||||
let cidr_str = req.cidr.clone();
|
||||
|
||||
// Spawn non-blocking background scan
|
||||
task::spawn(async move {
|
||||
run_cidr_scan(pool, scan_id_clone, cidr, agent_port).await;
|
||||
});
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::DiscoveryScanStarted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("discovery"),
|
||||
Some(&scan_id.to_string()),
|
||||
json!({ "cidr": cidr_str }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(scan_id = %scan_id, cidr = %req.cidr, "CIDR scan started");
|
||||
Ok(Json(
|
||||
json!({ "scan_id": scan_id, "message": "Discovery scan started", "cidr": req.cidr }),
|
||||
))
|
||||
}
|
||||
|
||||
/// Background CIDR scanner.
|
||||
async fn run_cidr_scan(pool: sqlx::PgPool, scan_id: Uuid, cidr: ipnet::IpNet, port: u16) {
|
||||
let semaphore = std::sync::Arc::new(Semaphore::new(MAX_CONCURRENT_PROBES));
|
||||
let hosts: Vec<IpAddr> = cidr.hosts().collect();
|
||||
let total = hosts.len();
|
||||
|
||||
tracing::info!(scan_id = %scan_id, total = total, "CIDR scan probing {} hosts", total);
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for ip in hosts {
|
||||
let sem = semaphore.clone();
|
||||
let pool_clone = pool.clone();
|
||||
let h = task::spawn(async move {
|
||||
let _permit = sem.acquire().await.ok()?;
|
||||
probe_and_store(pool_clone, scan_id, ip, port).await
|
||||
});
|
||||
handles.push(h);
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
let _ = h.await;
|
||||
}
|
||||
|
||||
tracing::info!(scan_id = %scan_id, "CIDR scan complete");
|
||||
}
|
||||
|
||||
/// Probe a single IP:port and store the result if the port is open.
|
||||
async fn probe_and_store(pool: sqlx::PgPool, scan_id: Uuid, ip: IpAddr, port: u16) -> Option<()> {
|
||||
let addr = format!("{ip}:{port}");
|
||||
|
||||
// TCP connect probe (blocking, run in thread pool)
|
||||
// TCP connect probe (blocking, run in thread pool)
|
||||
let addr_clone = addr.clone();
|
||||
let open = task::spawn_blocking(move || {
|
||||
TcpStream::connect_timeout(
|
||||
&match addr_clone.parse() {
|
||||
Ok(a) => a,
|
||||
Err(_) => return false,
|
||||
},
|
||||
Duration::from_secs(PROBE_TIMEOUT_SECS),
|
||||
)
|
||||
.is_ok()
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !open {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Reverse DNS lookup (best-effort)
|
||||
let ip_clone = ip;
|
||||
let fqdn = task::spawn_blocking(move || {
|
||||
use std::net::ToSocketAddrs;
|
||||
let addr = format!("{ip_clone}:{port}");
|
||||
addr.to_socket_addrs()
|
||||
.ok()
|
||||
.and_then(|mut a| a.next())
|
||||
.and_then(|_| dns_lookup_for_ip(ip_clone))
|
||||
})
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let _ = sqlx::query(
|
||||
r#"INSERT INTO discovery_results (scan_id, ip_address, fqdn, agent_port)
|
||||
VALUES ($1, $2::inet, $3, $4)
|
||||
ON CONFLICT DO NOTHING"#,
|
||||
)
|
||||
.bind(scan_id)
|
||||
.bind(ip.to_string())
|
||||
.bind(fqdn)
|
||||
.bind(port as i32)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
tracing::debug!(ip = %ip, port = port, "Discovered agent");
|
||||
Some(())
|
||||
}
|
||||
|
||||
/// Simple reverse DNS lookup.
|
||||
fn dns_lookup_for_ip(ip: IpAddr) -> Option<String> {
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
let _addr = SocketAddr::new(ip, 0);
|
||||
// Standard library doesn't have reverse lookup; use getaddrinfo via format
|
||||
let host = format!("{ip}");
|
||||
// Best-effort: try to resolve numeric address to hostname
|
||||
(host + ":0")
|
||||
.to_socket_addrs()
|
||||
.ok()?
|
||||
.next()
|
||||
.map(|a| a.ip().to_string())
|
||||
.filter(|s| s != &ip.to_string())
|
||||
}
|
||||
|
||||
// ── GET /api/v1/discovery/:scan_id ────────────────────────────────────────────
|
||||
|
||||
async fn get_scan_results(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(scan_id): Path<Uuid>,
|
||||
) -> Result<Json<Vec<DiscoveryResult>>, (StatusCode, Json<Value>)> {
|
||||
sqlx::query_as::<_, DiscoveryResult>(
|
||||
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
|
||||
agent_version, os_name, agent_port, discovered_at, registered
|
||||
FROM discovery_results
|
||||
WHERE scan_id = $1
|
||||
ORDER BY ip_address"#,
|
||||
)
|
||||
.bind(scan_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── POST /api/v1/discovery/:id/register ──────────────────────────────────────
|
||||
|
||||
async fn register_discovered_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<RegisterDiscoveredRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch discovery result
|
||||
let result: Option<DiscoveryResult> = sqlx::query_as(
|
||||
r#"SELECT id, scan_id, host(ip_address)::text AS ip_address, fqdn,
|
||||
agent_version, os_name, agent_port, discovered_at, registered
|
||||
FROM discovery_results WHERE id = $1"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let result = result.ok_or_else(|| (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Discovery result not found" } }))
|
||||
))?;
|
||||
|
||||
let fqdn = result.fqdn.as_deref().unwrap_or(&result.ip_address);
|
||||
let display_name = req.display_name.as_deref().unwrap_or(fqdn);
|
||||
|
||||
let host_id: Uuid = sqlx::query_scalar(
|
||||
r#"INSERT INTO hosts (fqdn, ip_address, display_name, agent_port)
|
||||
VALUES ($1, $2::inet, $3, $4)
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(fqdn)
|
||||
.bind(&result.ip_address)
|
||||
.bind(display_name)
|
||||
.bind(result.agent_port)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Assign to groups
|
||||
if let Some(group_ids) = &req.group_ids {
|
||||
for gid in group_ids {
|
||||
let _ = sqlx::query("INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING")
|
||||
.bind(host_id).bind(gid).execute(&state.db).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as registered
|
||||
let _ = sqlx::query("UPDATE discovery_results SET registered = TRUE WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::HostRegistered,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "from_discovery": true, "ip": result.ip_address }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(
|
||||
json!({ "host_id": host_id, "message": "Host registered from discovery" }),
|
||||
))
|
||||
}
|
||||
319
crates/pm-web/src/routes/enrollment.rs
Normal file
319
crates/pm-web/src/routes/enrollment.rs
Normal file
@ -0,0 +1,319 @@
|
||||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use pm_auth::AuthUser;
|
||||
use pm_core::{
|
||||
db,
|
||||
models::{
|
||||
CreateEnrollmentRequest, EnrollmentRequest, EnrollmentStatusResponse, Host, PkiBundle,
|
||||
},
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct HostConflict {
|
||||
pub existing_host: Host,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Define public enrollment routes.
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/enroll", post(enroll_host))
|
||||
.route("/enroll/status/{token}", get(enroll_status))
|
||||
}
|
||||
|
||||
/// POST /api/v1/enroll
|
||||
/// Initiates host self-enrollment.
|
||||
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
|
||||
async fn enroll_host(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<CreateEnrollmentRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Generate secure random polling token
|
||||
let polling_token: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(64)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
// For database storage, we'll hash the token (spec says hashed)
|
||||
// Using a simple SHA256 or similar for the hash storage
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(polling_token.as_bytes());
|
||||
let token_hash = hex::encode(hasher.finalize());
|
||||
|
||||
// 3. Store in DB
|
||||
db::create_enrollment_request(&state.db, payload, token_hash)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to create enrollment request");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 4. Return the raw token to the client
|
||||
Ok((
|
||||
StatusCode::ACCEPTED,
|
||||
Json(serde_json::json!({ "polling_token": polling_token })),
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
|
||||
/// GET /api/v1/enroll/status/{token}
|
||||
/// Returns status of enrollment (pending/approved/denied/not_found).
|
||||
async fn enroll_status(
|
||||
State(state): State<AppState>,
|
||||
Path(token): Path<String>,
|
||||
) -> Result<Json<EnrollmentStatusResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
// Hash the provided token to match DB
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let token_hash = hex::encode(hasher.finalize());
|
||||
|
||||
// 1. Check enrollment_requests table
|
||||
let requests = db::list_enrollment_requests(&state.db).await.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(req) = requests.into_iter().find(|r| r.polling_token == token_hash) {
|
||||
if req.expires_at < Utc::now() {
|
||||
return Ok(Json(EnrollmentStatusResponse::NotFound));
|
||||
}
|
||||
return Ok(Json(EnrollmentStatusResponse::Pending));
|
||||
}
|
||||
|
||||
// 2. If not in pending, check if it was recently approved.
|
||||
if let Some(pki) = state.approved_enrollments.get(&token_hash) {
|
||||
return Ok(Json(EnrollmentStatusResponse::Approved {
|
||||
ca_crt: pki.ca_crt.clone(),
|
||||
server_crt: pki.server_crt.clone(),
|
||||
server_key: pki.server_key.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(Json(EnrollmentStatusResponse::NotFound))
|
||||
}
|
||||
|
||||
/// Define admin enrollment routes.
|
||||
pub fn admin_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/enrollments", get(list_admin_enrollments))
|
||||
.route("/enrollments/{id}/approve", post(approve_enrollment))
|
||||
.route("/enrollments/{id}/deny", delete(deny_enrollment))
|
||||
}
|
||||
|
||||
/// GET /api/v1/admin/enrollments
|
||||
/// Lists all pending enrollment requests.
|
||||
async fn list_admin_enrollments(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Vec<EnrollmentRequest>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
db::list_enrollment_requests(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list enrollment requests");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// POST /api/v1/admin/enrollments/{id}/approve
|
||||
/// Approves a pending enrollment request, generates PKI, and moves to hosts table.
|
||||
async fn approve_enrollment(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<uuid::Uuid>,
|
||||
auth: AuthUser,
|
||||
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch the enrollment request
|
||||
let mut requests = db::list_enrollment_requests(&state.db).await.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list enrollment requests for approval");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let enrollment_request = match requests.iter().position(|r| r.id == id) {
|
||||
Some(idx) => requests.remove(idx),
|
||||
None => return Ok(StatusCode::NOT_FOUND),
|
||||
};
|
||||
|
||||
// Check for FQDN/IP collision in hosts table
|
||||
if let Some(existing_host) = sqlx::query_as::<_, Host>(
|
||||
"SELECT id, fqdn, ip_address::text, display_name, os_family, os_name, arch, agent_version, health_status, last_health_at, last_patch_at, agent_port, notes, registered_at, updated_at FROM hosts WHERE fqdn = $1 OR ip_address = $2::inet"
|
||||
)
|
||||
.bind(&enrollment_request.fqdn)
|
||||
.bind(enrollment_request.ip_address.to_string())
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to check for host collision");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": "Database error" })))
|
||||
})? {
|
||||
return Err((
|
||||
StatusCode::CONFLICT,
|
||||
Json(serde_json::json!({ "error": "Host collision detected", "conflict": HostConflict { existing_host, message: "FQDN or IP already exists".to_string() } }))
|
||||
));
|
||||
}
|
||||
|
||||
// Move to hosts table FIRST (certificates table has FK reference to hosts)
|
||||
let os_family = enrollment_request
|
||||
.os_details
|
||||
.get("os")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let os_name = enrollment_request
|
||||
.os_details
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
// Build os_name from os + os_version if "name" is absent
|
||||
let os = enrollment_request
|
||||
.os_details
|
||||
.get("os")
|
||||
.and_then(|v| v.as_str())?;
|
||||
let ver = enrollment_request
|
||||
.os_details
|
||||
.get("os_version")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
Some(format!("{} {}", os, ver).trim().to_string())
|
||||
});
|
||||
let arch = enrollment_request
|
||||
.os_details
|
||||
.get("architecture")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let display_name = enrollment_request
|
||||
.hostname
|
||||
.clone()
|
||||
.unwrap_or_else(|| enrollment_request.fqdn.clone());
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO hosts (id, fqdn, ip_address, os_family, os_name, arch, display_name, registered_at, updated_at)
|
||||
VALUES ($1, $2, $3::inet, $4, $5, $6, $7, NOW(), NOW())
|
||||
"#,
|
||||
)
|
||||
.bind(enrollment_request.id)
|
||||
.bind(&enrollment_request.fqdn)
|
||||
.bind(enrollment_request.ip_address.to_string())
|
||||
.bind(&os_family)
|
||||
.bind(&os_name)
|
||||
.bind(&arch)
|
||||
.bind(&display_name)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to insert host after approval");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Generate PKI bundle using CA (after host row exists)
|
||||
let issued = state
|
||||
.ca
|
||||
.issue_client_cert(
|
||||
enrollment_request.id,
|
||||
&enrollment_request.fqdn,
|
||||
&enrollment_request.ip_address.to_string(),
|
||||
&state.db,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to issue client certificate");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Certificate generation failed" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Delete from enrollment_requests table
|
||||
db::delete_enrollment_request(&state.db, id)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to delete enrollment request after approval");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Store PKI bundle in cache for client retrieval
|
||||
let pki = PkiBundle {
|
||||
ca_crt: issued.ca_root_pem,
|
||||
server_crt: issued.server_cert_pem,
|
||||
server_key: issued.server_key_pem,
|
||||
};
|
||||
state
|
||||
.approved_enrollments
|
||||
.insert(enrollment_request.polling_token.clone(), pki);
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/admin/enrollments/{id}/deny
|
||||
/// Denies and purges a pending enrollment request.
|
||||
async fn deny_enrollment(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<uuid::Uuid>,
|
||||
auth: AuthUser,
|
||||
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
db::delete_enrollment_request(&state.db, id)
|
||||
.await
|
||||
.map(|_| StatusCode::NO_CONTENT)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to deny enrollment request");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": "Database error" })),
|
||||
)
|
||||
})
|
||||
}
|
||||
312
crates/pm-web/src/routes/groups.rs
Executable file
312
crates/pm-web/src/routes/groups.rs
Executable file
@ -0,0 +1,312 @@
|
||||
//! Group management routes.
|
||||
//!
|
||||
//! GET /api/v1/groups — list all groups
|
||||
//! POST /api/v1/groups — create group (admin)
|
||||
//! GET /api/v1/groups/:id — get group detail + members
|
||||
//! PUT /api/v1/groups/:id — update group (admin)
|
||||
//! DELETE /api/v1/groups/:id — delete group (admin)
|
||||
//! POST /api/v1/groups/:id/users/:user_id — add user to group (admin)
|
||||
//! DELETE /api/v1/groups/:id/users/:user_id — remove user from group (admin)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateGroupRequest, Group, UpdateGroupRequest},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_groups).post(create_group))
|
||||
.route(
|
||||
"/{id}",
|
||||
get(get_group).put(update_group).delete(delete_group),
|
||||
)
|
||||
.route(
|
||||
"/{id}/users/{user_id}",
|
||||
post(add_user_to_group).delete(remove_user_from_group),
|
||||
)
|
||||
}
|
||||
|
||||
async fn list_groups(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
|
||||
sqlx::query_as::<_, Group>(
|
||||
"SELECT id, name, description, created_at, updated_at FROM groups ORDER BY name",
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list groups");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateGroupRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let id: Uuid =
|
||||
sqlx::query_scalar("INSERT INTO groups (name, description) VALUES ($1, $2) RETURNING id")
|
||||
.bind(&req.name)
|
||||
.bind(req.description.as_deref().unwrap_or(""))
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"Group name already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("group"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "name": req.name }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "id": id, "message": "Group created" })))
|
||||
}
|
||||
|
||||
async fn get_group(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let group: Option<Group> = sqlx::query_as(
|
||||
"SELECT id, name, description, created_at, updated_at FROM groups WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let group = group.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Fetch member counts
|
||||
let host_count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM host_groups WHERE group_id = $1")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let user_count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM user_groups WHERE group_id = $1")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok(Json(
|
||||
json!({ "group": group, "host_count": host_count, "user_count": user_count }),
|
||||
))
|
||||
}
|
||||
|
||||
async fn update_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<UpdateGroupRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query(
|
||||
"UPDATE groups SET name = COALESCE($1, name), description = COALESCE($2, description), updated_at = NOW() WHERE id = $3"
|
||||
)
|
||||
.bind(req.name.as_deref())
|
||||
.bind(req.description.as_deref())
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } }))))?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Json(json!({ "message": "Group updated" })))
|
||||
}
|
||||
|
||||
async fn delete_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query("DELETE FROM groups WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Group not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("group"),
|
||||
Some(&id.to_string()),
|
||||
json!({}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Group deleted" })))
|
||||
}
|
||||
|
||||
async fn add_user_to_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((id, user_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO user_groups (user_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user_group"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "user_id": user_id, "action": "added" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User added to group" })))
|
||||
}
|
||||
|
||||
async fn remove_user_from_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((id, user_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM user_groups WHERE user_id = $1 AND group_id = $2")
|
||||
.bind(user_id)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user_group"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "user_id": user_id, "action": "removed" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User removed from group" })))
|
||||
}
|
||||
1159
crates/pm-web/src/routes/health_checks.rs
Executable file
1159
crates/pm-web/src/routes/health_checks.rs
Executable file
File diff suppressed because it is too large
Load Diff
678
crates/pm-web/src/routes/hosts.rs
Executable file
678
crates/pm-web/src/routes/hosts.rs
Executable file
@ -0,0 +1,678 @@
|
||||
//! Host management routes.
|
||||
//!
|
||||
//! GET /api/v1/hosts — list hosts (RBAC scoped)
|
||||
//! POST /api/v1/hosts — register new host (admin only)
|
||||
//! GET /api/v1/hosts/{id} — get host detail
|
||||
//! DELETE /api/v1/hosts/{id} — remove host (admin only)
|
||||
//! PUT /api/v1/hosts/{id} — update host (write access)
|
||||
//! GET /api/v1/hosts/{id}/groups — list groups for host
|
||||
//! POST /api/v1/hosts/{id}/groups — assign host to group
|
||||
//! DELETE /api/v1/hosts/{id}/groups/{group_id} — remove host from group
|
||||
//! POST /api/v1/hosts/{id}/refresh — queue on-demand refresh (write access)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateHostRequest, Group, HostSummary, UpdateHostRequest},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_hosts).post(register_host))
|
||||
.route("/{id}", get(get_host).put(update_host).delete(remove_host))
|
||||
.route(
|
||||
"/{id}/groups",
|
||||
get(list_host_groups).post(add_host_to_group),
|
||||
)
|
||||
.route("/{id}/groups/{group_id}", delete(remove_host_from_group))
|
||||
.route("/{id}/refresh", post(refresh_host))
|
||||
}
|
||||
|
||||
// ── Query params ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct HostListQuery {
|
||||
pub group_id: Option<Uuid>,
|
||||
pub health_status: Option<String>,
|
||||
pub os_family: Option<String>,
|
||||
pub search: Option<String>,
|
||||
pub limit: Option<i64>,
|
||||
pub offset: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Response types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct HostListResponse {
|
||||
hosts: Vec<HostSummary>,
|
||||
total: i64,
|
||||
limit: i64,
|
||||
offset: i64,
|
||||
}
|
||||
|
||||
// ── Helper: check if operator can access a host ───────────────────────────────
|
||||
|
||||
async fn operator_can_access_host(
|
||||
pool: &sqlx::PgPool,
|
||||
user_id: Uuid,
|
||||
host_id: Uuid,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
// Admins can access all; operators can access hosts in their groups
|
||||
// OR ungrouped hosts (no group memberships)
|
||||
let in_group: bool = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM host_groups hg
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE hg.host_id = $1 AND ug.user_id = $2
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if in_group {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Ungrouped hosts are accessible to any operator
|
||||
let ungrouped: bool =
|
||||
sqlx::query_scalar("SELECT NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(ungrouped)
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn list_hosts(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Query(q): Query<HostListQuery>,
|
||||
) -> Result<Json<HostListResponse>, (StatusCode, Json<Value>)> {
|
||||
let limit = q.limit.unwrap_or(50).min(200);
|
||||
let offset = q.offset.unwrap_or(0);
|
||||
|
||||
// For operators: only show hosts in their groups (or ungrouped)
|
||||
let hosts: Vec<HostSummary> = if auth.role.is_admin() {
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT h.id, h.fqdn, host(h.ip_address)::text AS ip_address, h.display_name,
|
||||
h.os_family, h.os_name, h.health_status, h.agent_version,
|
||||
COALESCE(hpd.patch_count, 0) AS patches_missing,
|
||||
CASE
|
||||
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
|
||||
THEN NULL
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM host_health_checks hc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT healthy FROM host_health_check_results r
|
||||
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
|
||||
) lr ON TRUE
|
||||
WHERE hc.host_id = h.id AND hc.enabled = TRUE
|
||||
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
|
||||
)
|
||||
THEN 'some_unhealthy'
|
||||
ELSE 'all_healthy'
|
||||
END AS health_check_status,
|
||||
h.registered_at
|
||||
FROM hosts h
|
||||
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
|
||||
ORDER BY h.fqdn
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT h.id, h.fqdn, host(h.ip_address)::text AS ip_address,
|
||||
h.display_name, h.os_family, h.os_name,
|
||||
h.health_status, h.agent_version,
|
||||
COALESCE(hpd.patch_count, 0) AS patches_missing,
|
||||
CASE
|
||||
WHEN NOT EXISTS (SELECT 1 FROM host_health_checks hc WHERE hc.host_id = h.id AND hc.enabled = TRUE)
|
||||
THEN NULL
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM host_health_checks hc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT healthy FROM host_health_check_results r
|
||||
WHERE r.check_id = hc.id ORDER BY r.checked_at DESC LIMIT 1
|
||||
) lr ON TRUE
|
||||
WHERE hc.host_id = h.id AND hc.enabled = TRUE
|
||||
AND (lr.healthy IS NULL OR lr.healthy = FALSE)
|
||||
)
|
||||
THEN 'some_unhealthy'
|
||||
ELSE 'all_healthy'
|
||||
END AS health_check_status,
|
||||
h.registered_at
|
||||
FROM hosts h
|
||||
LEFT JOIN host_patch_data hpd ON hpd.host_id = h.id
|
||||
WHERE
|
||||
-- Hosts in operator's groups
|
||||
EXISTS (
|
||||
SELECT 1 FROM host_groups hg
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE hg.host_id = h.id AND ug.user_id = $3
|
||||
)
|
||||
-- OR ungrouped hosts
|
||||
OR NOT EXISTS (SELECT 1 FROM host_groups WHERE host_id = h.id)
|
||||
ORDER BY h.fqdn
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.bind(auth.user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
}
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM hosts")
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok(Json(HostListResponse {
|
||||
hosts,
|
||||
total,
|
||||
limit,
|
||||
offset,
|
||||
}))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts ────────────────────────────────────────────────────────
|
||||
|
||||
async fn register_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateHostRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Admin only
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Resolve FQDN to IP address
|
||||
let ip_address = resolve_fqdn(&req.fqdn).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "fqdn_resolution_failed", "message": e } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let display_name = req.display_name.clone().unwrap_or_else(|| req.fqdn.clone());
|
||||
let agent_port = req.agent_port.unwrap_or(12443);
|
||||
let notes = req.notes.clone().unwrap_or_default();
|
||||
|
||||
// Insert host
|
||||
let host_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO hosts (fqdn, ip_address, display_name, agent_port, notes)
|
||||
VALUES ($1, $2::inet, $3, $4, $5)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(&req.fqdn)
|
||||
.bind(&ip_address)
|
||||
.bind(&display_name)
|
||||
.bind(agent_port)
|
||||
.bind(¬es)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"Host with this FQDN and IP already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
tracing::error!(error = %e, "Failed to register host");
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Assign to groups if specified
|
||||
if let Some(group_ids) = &req.group_ids {
|
||||
for gid in group_ids {
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(gid)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Audit log
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::HostRegistered,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&host_id.to_string()),
|
||||
json!({ "fqdn": req.fqdn, "ip": ip_address }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(host_id = %host_id, fqdn = %req.fqdn, "Host registered");
|
||||
Ok(Json(json!({ "id": host_id, "message": "Host registered" })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn get_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !can_access {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let host: Option<Value> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT row_to_json(h) FROM (
|
||||
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
|
||||
os_family, os_name, arch, agent_version, health_status,
|
||||
last_health_at, last_patch_at, agent_port, notes,
|
||||
registered_at, updated_at
|
||||
FROM hosts WHERE id = $1
|
||||
) h
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to get host");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
host.map(Json).ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:id ──────────────────────────────────────────────────
|
||||
|
||||
async fn remove_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch FQDN for audit before deletion
|
||||
let fqdn: Option<String> = sqlx::query_scalar("SELECT fqdn FROM hosts WHERE id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let result = sqlx::query("DELETE FROM hosts WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to remove host");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::HostRemoved,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "fqdn": fqdn }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(host_id = %id, "Host removed");
|
||||
Ok(Json(json!({ "message": "Host removed" })))
|
||||
}
|
||||
|
||||
// ── PUT /api/v1/hosts/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn update_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<UpdateHostRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Update only fields that were provided; COALESCE preserves existing values.
|
||||
let host = sqlx::query_scalar(
|
||||
r#"
|
||||
WITH updated AS (
|
||||
UPDATE hosts SET
|
||||
fqdn = COALESCE($1, fqdn),
|
||||
ip_address = COALESCE($2::inet, ip_address),
|
||||
display_name = COALESCE($3, display_name),
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
RETURNING id
|
||||
)
|
||||
SELECT row_to_json(h) FROM (
|
||||
SELECT id, fqdn, host(ip_address)::text AS ip_address, display_name,
|
||||
os_family, os_name, arch, agent_version, health_status,
|
||||
last_health_at, last_patch_at, agent_port, notes,
|
||||
registered_at, updated_at
|
||||
FROM hosts WHERE id = (SELECT id FROM updated)
|
||||
) h
|
||||
"#,
|
||||
)
|
||||
.bind(&req.fqdn)
|
||||
.bind(&req.ip_address)
|
||||
.bind(&req.display_name)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, host_id = %id, "Failed to update host");
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"A host with this FQDN and IP already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
host.map(Json).ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:id/groups ──────────────────────────────────────────────
|
||||
|
||||
async fn list_host_groups(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Vec<Group>>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
let can_access = operator_can_access_host(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !can_access {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let groups: Vec<Group> = sqlx::query_as(
|
||||
r#"SELECT g.id, g.name, g.description, g.created_at, g.updated_at
|
||||
FROM groups g
|
||||
JOIN host_groups hg ON hg.group_id = g.id
|
||||
WHERE hg.host_id = $1
|
||||
ORDER BY g.name"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to list host groups");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(groups))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:id/groups ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AddToGroupRequest {
|
||||
group_id: Uuid,
|
||||
}
|
||||
|
||||
async fn add_host_to_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<AddToGroupRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO host_groups (host_id, group_id) VALUES ($1, $2) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(id)
|
||||
.bind(req.group_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to add host to group");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "group_id": req.group_id, "action": "added" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Host added to group" })))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:id/groups/:group_id ─────────────────────────────────
|
||||
|
||||
async fn remove_host_from_group(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((id, group_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM host_groups WHERE host_id = $1 AND group_id = $2")
|
||||
.bind(id)
|
||||
.bind(group_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to remove host from group");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::GroupMembershipChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("host"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "group_id": group_id, "action": "removed" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Host removed from group" })))
|
||||
}
|
||||
|
||||
// ── FQDN resolution ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Resolve an FQDN (or IP) to its primary IP address.
|
||||
/// If the input is already a valid IP, returns it as-is.
|
||||
async fn resolve_fqdn(fqdn: &str) -> Result<String, String> {
|
||||
use std::net::ToSocketAddrs;
|
||||
// Try direct IP parse first
|
||||
if fqdn.parse::<std::net::IpAddr>().is_ok() {
|
||||
return Ok(fqdn.to_string());
|
||||
}
|
||||
// DNS resolution
|
||||
let addr = format!("{fqdn}:0");
|
||||
match tokio::task::spawn_blocking(move || addr.to_socket_addrs()).await {
|
||||
Ok(Ok(mut addrs)) => addrs
|
||||
.next()
|
||||
.map(|a| a.ip().to_string())
|
||||
.ok_or_else(|| format!("No addresses found for {fqdn}")),
|
||||
_ => Err(format!("Failed to resolve FQDN: {fqdn}")),
|
||||
}
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:id/refresh ───────────────────────────────────────────
|
||||
|
||||
/// Queue an on-demand health + patch refresh for a single host.
|
||||
///
|
||||
/// Sends a PostgreSQL NOTIFY on the `refresh_requested` channel; the
|
||||
/// pm-worker refresh listener picks this up and polls the host immediately.
|
||||
/// Requires Operator or Admin role (any authenticated user).
|
||||
async fn refresh_host(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<(StatusCode, Json<Value>), (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
// Verify the host exists.
|
||||
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "refresh_host: db error checking host existence");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !exists {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// NOTIFY the worker's refresh listener.
|
||||
sqlx::query("SELECT pg_notify('refresh_requested', $1)")
|
||||
.bind(id.to_string())
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "refresh_host: pg_notify failed");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to queue refresh" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(%id, "On-demand refresh queued");
|
||||
|
||||
Ok((
|
||||
StatusCode::ACCEPTED,
|
||||
Json(json!({ "message": "Refresh queued" })),
|
||||
))
|
||||
}
|
||||
677
crates/pm-web/src/routes/jobs.rs
Executable file
677
crates/pm-web/src/routes/jobs.rs
Executable file
@ -0,0 +1,677 @@
|
||||
//! Patch job management routes.
|
||||
//!
|
||||
//! POST /api/v1/jobs — create a new patch job (operator+)
|
||||
//! GET /api/v1/jobs — list jobs with pagination (RBAC scoped)
|
||||
//! GET /api/v1/jobs/{id} — get job detail + per-host status
|
||||
//! POST /api/v1/jobs/{id}/cancel — cancel a queued/pending job (admin or creator)
|
||||
//! POST /api/v1/jobs/{id}/rollback — create a rollback job (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateJobRequest, PatchJobSummary},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_jobs).post(create_job))
|
||||
.route("/{id}", get(get_job))
|
||||
.route("/{id}/cancel", post(cancel_job))
|
||||
.route("/{id}/rollback", post(rollback_job))
|
||||
}
|
||||
|
||||
// ── Query params ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct JobListQuery {
|
||||
pub limit: Option<i64>,
|
||||
pub offset: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Response types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct JobListResponse {
|
||||
jobs: Vec<PatchJobSummary>,
|
||||
total: i64,
|
||||
limit: i64,
|
||||
offset: i64,
|
||||
}
|
||||
|
||||
/// Per-host row included in `GET /api/v1/jobs/{id}` response.
|
||||
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
|
||||
struct JobHostRow {
|
||||
pub id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub display_name: String,
|
||||
pub status: String,
|
||||
pub agent_job_id: Option<String>,
|
||||
pub retry_count: i32,
|
||||
pub output: String,
|
||||
pub error_message: Option<String>,
|
||||
pub last_error: Option<String>,
|
||||
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
// ── Error helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── RBAC helper ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Returns `true` when the operator's groups contain at least one host that
|
||||
/// belongs to the given job. Admins always pass this check at the call site.
|
||||
async fn operator_can_access_job(
|
||||
pool: &sqlx::PgPool,
|
||||
user_id: Uuid,
|
||||
job_id: Uuid,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN host_groups hg ON hg.host_id = pjh.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh.job_id = $1
|
||||
AND ug.user_id = $2
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn create_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateJobRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
if req.host_ids.is_empty() {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"host_ids must not be empty",
|
||||
));
|
||||
}
|
||||
|
||||
// Encode package list as JSONB.
|
||||
let patch_selection = serde_json::to_value(&req.packages).unwrap_or(json!([]));
|
||||
let notes = req.notes.clone().unwrap_or_default();
|
||||
|
||||
// Insert the parent job row; the DB NOTIFY trigger fires automatically
|
||||
// when immediate = TRUE (see migration 003_jobs_scheduling.sql).
|
||||
let job_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, created_by_user_id, maintenance_window_id,
|
||||
immediate, patch_selection, notes)
|
||||
VALUES
|
||||
('patch_apply'::job_kind, 'queued'::job_status, $1, $2, $3, $4, $5)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.bind(req.maintenance_window_id)
|
||||
.bind(req.immediate)
|
||||
.bind(&patch_selection)
|
||||
.bind(¬es)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "create_job: insert patch_jobs failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Insert one patch_job_hosts row per requested host.
|
||||
for host_id in &req.host_ids {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
VALUES ($1, $2, 'queued'::job_status)
|
||||
ON CONFLICT (job_id, host_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e, %job_id, %host_id,
|
||||
"create_job: insert patch_job_hosts failed"
|
||||
);
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&job_id.to_string()),
|
||||
json!({
|
||||
"kind": "patch_apply",
|
||||
"immediate": req.immediate,
|
||||
"host_count": req.host_ids.len(),
|
||||
"packages": req.packages,
|
||||
"notes": notes,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
job_id = %job_id,
|
||||
host_count = req.host_ids.len(),
|
||||
immediate = req.immediate,
|
||||
user = %auth.username,
|
||||
"Patch job created"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "id": job_id, "message": "Job created" })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/jobs ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn list_jobs(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Query(q): Query<JobListQuery>,
|
||||
) -> Result<Json<JobListResponse>, (StatusCode, Json<Value>)> {
|
||||
let limit = q.limit.unwrap_or(50).min(200);
|
||||
let offset = q.offset.unwrap_or(0);
|
||||
|
||||
let jobs: Vec<PatchJobSummary> = if auth.role.is_admin() {
|
||||
// Admins see every job.
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.kind,
|
||||
pj.status,
|
||||
pj.immediate,
|
||||
pj.notes,
|
||||
pj.created_at,
|
||||
pj.started_at,
|
||||
pj.completed_at,
|
||||
COUNT(pjh.id) AS host_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
|
||||
FROM patch_jobs pj
|
||||
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
|
||||
GROUP BY pj.id
|
||||
ORDER BY pj.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
} else {
|
||||
// Operators: only jobs where at least one host is in their groups.
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.kind,
|
||||
pj.status,
|
||||
pj.immediate,
|
||||
pj.notes,
|
||||
pj.created_at,
|
||||
pj.started_at,
|
||||
pj.completed_at,
|
||||
COUNT(pjh.id) AS host_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
|
||||
FROM patch_jobs pj
|
||||
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh2
|
||||
JOIN host_groups hg ON hg.host_id = pjh2.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh2.job_id = pj.id
|
||||
AND ug.user_id = $3
|
||||
)
|
||||
GROUP BY pj.id
|
||||
ORDER BY pj.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.bind(auth.user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
}
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "list_jobs: query failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
// Total count for pagination metadata.
|
||||
let total: i64 = if auth.role.is_admin() {
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM patch_jobs")
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
} else {
|
||||
sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(DISTINCT pj.id)
|
||||
FROM patch_jobs pj
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN host_groups hg ON hg.host_id = pjh.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh.job_id = pj.id
|
||||
AND ug.user_id = $1
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
};
|
||||
|
||||
Ok(Json(JobListResponse {
|
||||
jobs,
|
||||
total,
|
||||
limit,
|
||||
offset,
|
||||
}))
|
||||
}
|
||||
// ── GET /api/v1/jobs/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn get_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// RBAC: operators may only view jobs touching their group's hosts.
|
||||
if !auth.role.is_admin() {
|
||||
let allowed = operator_can_access_job(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !allowed {
|
||||
return Err(err(StatusCode::FORBIDDEN, "forbidden", "Access denied"));
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the job header row as JSON.
|
||||
let job: Option<Value> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT row_to_json(j) FROM (
|
||||
SELECT id, kind, status, created_by_user_id, parent_job_id,
|
||||
maintenance_window_id, immediate, patch_selection, notes,
|
||||
created_at, started_at, completed_at
|
||||
FROM patch_jobs
|
||||
WHERE id = $1
|
||||
) j
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "get_job: failed to fetch job");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
let job = job.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
|
||||
|
||||
// Fetch per-host status rows joined to the host display name.
|
||||
let hosts: Vec<JobHostRow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pjh.id,
|
||||
pjh.host_id,
|
||||
COALESCE(h.display_name, h.fqdn) AS display_name,
|
||||
pjh.status::text AS status,
|
||||
pjh.agent_job_id,
|
||||
pjh.retry_count,
|
||||
pjh.output,
|
||||
pjh.error_message,
|
||||
pjh.last_error,
|
||||
pjh.started_at,
|
||||
pjh.completed_at
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.job_id = $1
|
||||
ORDER BY h.display_name
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "get_job: failed to fetch host rows");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "job": job, "hosts": hosts })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs/:id/cancel ─────────────────────────────────────────────
|
||||
|
||||
async fn cancel_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Fetch the job to verify it exists and check ownership.
|
||||
let row: Option<(String, Option<Uuid>)> =
|
||||
sqlx::query_as("SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: db fetch failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
let (status_str, creator_id) =
|
||||
row.ok_or_else(|| err(StatusCode::NOT_FOUND, "not_found", "Job not found"))?;
|
||||
|
||||
// Only admin or the job creator may cancel.
|
||||
if !auth.role.can_write() {
|
||||
let is_creator = creator_id == Some(auth.user_id);
|
||||
if !is_creator {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Only queued or pending jobs can be cancelled.
|
||||
if status_str != "queued" && status_str != "pending" {
|
||||
return Err(err(
|
||||
StatusCode::CONFLICT,
|
||||
"invalid_state",
|
||||
format!(
|
||||
"Cannot cancel a job in '{}' state; only queued or pending jobs may be cancelled",
|
||||
status_str
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Cancel the parent job.
|
||||
sqlx::query("UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Cancel all queued/pending host rows for this job.
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'cancelled'::job_status
|
||||
WHERE job_id = $1
|
||||
AND status IN ('queued'::job_status, 'pending'::job_status)
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: update patch_job_hosts failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Fire job-level pg_notify so the frontend can update the job row.
|
||||
let notify_payload = json!({
|
||||
"event_type": "job",
|
||||
"job_id": id.to_string(),
|
||||
"host_id": "",
|
||||
"status": "cancelled",
|
||||
"succeeded_count": 0,
|
||||
"failed_count": 0,
|
||||
"host_count": 0,
|
||||
});
|
||||
if let Ok(payload_str) = serde_json::to_string(¬ify_payload) {
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_str)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, %id, "cancel_job: job-level pg_notify failed");
|
||||
} else {
|
||||
tracing::info!(%id, "cancel_job: job-level pg_notify sent");
|
||||
}
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobCancelled,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "previous_status": status_str }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(job_id = %id, user = %auth.username, "Patch job cancelled");
|
||||
Ok(Json(json!({ "message": "Job cancelled" })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs/:id/rollback ────────────────────────────────────────────
|
||||
|
||||
async fn rollback_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Admin-only operation.
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
|
||||
// Verify the original job exists.
|
||||
let original_exists: bool =
|
||||
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM patch_jobs WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "rollback_job: existence check failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !original_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Job not found"));
|
||||
}
|
||||
|
||||
// Gather the host IDs from the original job.
|
||||
let host_ids: Vec<Uuid> =
|
||||
sqlx::query_scalar("SELECT host_id FROM patch_job_hosts WHERE job_id = $1")
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "rollback_job: host fetch failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if host_ids.is_empty() {
|
||||
return Err(err(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"no_hosts",
|
||||
"Original job has no host entries to roll back",
|
||||
));
|
||||
}
|
||||
|
||||
// Create the rollback job row (immediate = true so the worker picks it up
|
||||
// right away and the NOTIFY trigger fires).
|
||||
let rollback_job_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, created_by_user_id, parent_job_id, immediate,
|
||||
patch_selection, notes)
|
||||
VALUES
|
||||
('rollback'::job_kind, 'queued'::job_status, $1, $2, TRUE,
|
||||
'[]'::jsonb, $3)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.bind(id)
|
||||
.bind(format!("Rollback of job {}", id))
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, parent_job_id = %id, "rollback_job: insert failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Replicate host list into the rollback job.
|
||||
for host_id in &host_ids {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
VALUES ($1, $2, 'queued'::job_status)
|
||||
ON CONFLICT (job_id, host_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(rollback_job_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e, %rollback_job_id, %host_id,
|
||||
"rollback_job: insert patch_job_hosts failed"
|
||||
);
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobRollback,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&rollback_job_id.to_string()),
|
||||
json!({
|
||||
"original_job_id": id,
|
||||
"rollback_job_id": rollback_job_id,
|
||||
"host_count": host_ids.len(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
rollback_job_id = %rollback_job_id,
|
||||
original_job_id = %id,
|
||||
user = %auth.username,
|
||||
"Rollback job created"
|
||||
);
|
||||
|
||||
Ok(Json(json!({
|
||||
"id": rollback_job_id,
|
||||
"parent_job_id": id,
|
||||
"message": "Rollback job created"
|
||||
})))
|
||||
}
|
||||
452
crates/pm-web/src/routes/maintenance_windows.rs
Normal file
452
crates/pm-web/src/routes/maintenance_windows.rs
Normal file
@ -0,0 +1,452 @@
|
||||
//! Maintenance window management routes.
|
||||
//!
|
||||
//! GET /api/v1/hosts/{id}/maintenance-windows — list windows for host
|
||||
//! GET /api/v1/maintenance-windows — list ALL windows (bulk)
|
||||
//! POST /api/v1/hosts/{id}/maintenance-windows — create window for host
|
||||
//! PUT /api/v1/hosts/{id}/maintenance-windows/{win_id} — update window
|
||||
//! DELETE /api/v1/hosts/{id}/maintenance-windows/{win_id} — delete window
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, put},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Mount as a nested router under `/hosts/{host_id}/maintenance-windows`.
|
||||
/// Axum will merge the `{host_id}` path segment from the parent nest.
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_windows).post(create_window))
|
||||
.route("/{win_id}", put(update_window).delete(delete_window))
|
||||
}
|
||||
|
||||
/// Top-level router for `/api/v1/maintenance-windows` — bulk list-all endpoint.
|
||||
pub fn all_windows_router() -> Router<AppState> {
|
||||
Router::new().route("/", get(list_all_windows))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/maintenance-windows ──────────────────────────────────────────
|
||||
|
||||
/// Bulk endpoint: return every maintenance window across all hosts.
|
||||
/// Eliminates N+1 queries from the frontend (one request instead of one per host).
|
||||
async fn list_all_windows(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
ORDER BY host_id, created_at ASC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "list_all_windows: query failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "windows": windows })))
|
||||
}
|
||||
|
||||
// ── Error helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:host_id/maintenance-windows ────────────────────────────
|
||||
|
||||
async fn list_windows(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Verify host exists.
|
||||
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "list_windows: host existence check failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !host_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
|
||||
}
|
||||
|
||||
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
WHERE host_id = $1
|
||||
ORDER BY created_at ASC
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "list_windows: query failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "windows": windows })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/maintenance-windows ───────────────────────────
|
||||
|
||||
async fn create_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
Json(req): Json<CreateMaintenanceWindowRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
// Validate: weekly requires recurrence_day 0-6
|
||||
if req.recurrence == pm_core::models::WindowRecurrence::Weekly {
|
||||
match req.recurrence_day {
|
||||
Some(d) if (0..=6).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Weekly recurrence requires recurrence_day 0-6 (0=Sunday)",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate: monthly requires recurrence_day 1-31
|
||||
if req.recurrence == pm_core::models::WindowRecurrence::Monthly {
|
||||
match req.recurrence_day {
|
||||
Some(d) if (1..=31).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Monthly recurrence requires recurrence_day 1-31",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Verify host exists.
|
||||
let host_exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "create_window: host existence check failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !host_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
|
||||
}
|
||||
|
||||
let duration = req.duration_minutes.unwrap_or(60);
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
let auto_apply = req.auto_apply.unwrap_or(true);
|
||||
|
||||
let window: MaintenanceWindow = sqlx::query_as(
|
||||
r#"
|
||||
INSERT INTO maintenance_windows
|
||||
(host_id, label, recurrence, start_at, duration_minutes, recurrence_day, enabled, auto_apply)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(&req.label)
|
||||
.bind(&req.recurrence)
|
||||
.bind(req.start_at)
|
||||
.bind(duration)
|
||||
.bind(req.recurrence_day)
|
||||
.bind(enabled)
|
||||
.bind(auto_apply)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "create_window: insert failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&window.id.to_string()),
|
||||
json!({
|
||||
"host_id": host_id,
|
||||
"label": window.label,
|
||||
"recurrence": window.recurrence.to_string(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %window.id,
|
||||
%host_id,
|
||||
recurrence = %window.recurrence,
|
||||
user = %auth.username,
|
||||
"Maintenance window created"
|
||||
);
|
||||
|
||||
Ok(Json(json!(window)))
|
||||
}
|
||||
|
||||
// ── PUT /api/v1/hosts/:host_id/maintenance-windows/:win_id ───────────────────
|
||||
|
||||
async fn update_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
|
||||
Json(req): Json<UpdateMaintenanceWindowRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
// Fetch existing record (verify ownership and existence).
|
||||
let existing: Option<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
WHERE id = $1 AND host_id = $2
|
||||
"#,
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "update_window: fetch failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
let existing = existing.ok_or_else(|| {
|
||||
err(
|
||||
StatusCode::NOT_FOUND,
|
||||
"not_found",
|
||||
"Maintenance window not found",
|
||||
)
|
||||
})?;
|
||||
|
||||
// Apply partial updates using existing values as defaults.
|
||||
let new_label = req.label.unwrap_or(existing.label);
|
||||
let new_recurrence = req.recurrence.unwrap_or(existing.recurrence);
|
||||
let new_start_at = req.start_at.unwrap_or(existing.start_at);
|
||||
let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes);
|
||||
let new_rec_day = req.recurrence_day.or(existing.recurrence_day);
|
||||
let new_enabled = req.enabled.unwrap_or(existing.enabled);
|
||||
let new_auto_apply = req.auto_apply.unwrap_or(existing.auto_apply);
|
||||
|
||||
// Validate recurrence_day for the final recurrence type.
|
||||
if new_recurrence == pm_core::models::WindowRecurrence::Weekly {
|
||||
match new_rec_day {
|
||||
Some(d) if (0..=6).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Weekly recurrence requires recurrence_day 0-6",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
if new_recurrence == pm_core::models::WindowRecurrence::Monthly {
|
||||
match new_rec_day {
|
||||
Some(d) if (1..=31).contains(&d) => {},
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Monthly recurrence requires recurrence_day 1-31",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let updated: MaintenanceWindow = sqlx::query_as(
|
||||
r#"
|
||||
UPDATE maintenance_windows
|
||||
SET label = $3,
|
||||
recurrence = $4,
|
||||
start_at = $5,
|
||||
duration_minutes = $6,
|
||||
recurrence_day = $7,
|
||||
enabled = $8,
|
||||
auto_apply = $9,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND host_id = $2
|
||||
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, auto_apply, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.bind(&new_label)
|
||||
.bind(&new_recurrence)
|
||||
.bind(new_start_at)
|
||||
.bind(new_duration)
|
||||
.bind(new_rec_day)
|
||||
.bind(new_enabled)
|
||||
.bind(new_auto_apply)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "update_window: update failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&win_id.to_string()),
|
||||
json!({ "host_id": host_id }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %win_id,
|
||||
%host_id,
|
||||
user = %auth.username,
|
||||
"Maintenance window updated"
|
||||
);
|
||||
|
||||
Ok(Json(json!(updated)))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:host_id/maintenance-windows/:win_id ────────────────
|
||||
|
||||
async fn delete_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Write access required",
|
||||
));
|
||||
}
|
||||
let result = sqlx::query("DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2")
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "delete_window: delete failed");
|
||||
err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"internal_error",
|
||||
"Database error",
|
||||
)
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(err(
|
||||
StatusCode::NOT_FOUND,
|
||||
"not_found",
|
||||
"Maintenance window not found",
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&win_id.to_string()),
|
||||
json!({ "host_id": host_id }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %win_id,
|
||||
%host_id,
|
||||
user = %auth.username,
|
||||
"Maintenance window deleted"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "message": "Maintenance window deleted" })))
|
||||
}
|
||||
17
crates/pm-web/src/routes/mod.rs
Executable file
17
crates/pm-web/src/routes/mod.rs
Executable file
@ -0,0 +1,17 @@
|
||||
//! Route modules for the pm-web API.
|
||||
pub mod auth;
|
||||
pub mod ca;
|
||||
pub mod discovery;
|
||||
pub mod enrollment;
|
||||
pub mod groups;
|
||||
pub mod health_checks;
|
||||
pub mod hosts;
|
||||
pub mod jobs;
|
||||
pub mod maintenance_windows;
|
||||
pub mod settings;
|
||||
pub mod sso;
|
||||
pub mod status;
|
||||
pub mod users;
|
||||
pub mod ws;
|
||||
|
||||
pub mod reports;
|
||||
163
crates/pm-web/src/routes/reports.rs
Executable file
163
crates/pm-web/src/routes/reports.rs
Executable file
@ -0,0 +1,163 @@
|
||||
//! Report generation endpoints.
|
||||
//!
|
||||
//! GET /api/v1/reports/compliance?format=csv|pdf&from=...&to=...&group_id=...
|
||||
//! GET /api/v1/reports/patch-history?format=csv|pdf&from=...&to=...
|
||||
//! GET /api/v1/reports/vulnerability?format=csv|pdf&from=...&to=...
|
||||
//! GET /api/v1/reports/audit?format=csv|pdf&from=...&to=...
|
||||
|
||||
use axum::{
|
||||
body::Bytes,
|
||||
extract::{Query, State},
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use pm_reports::{ReportParams, ReportType};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ReportQuery {
|
||||
/// "csv" or "pdf" (defaults to "csv")
|
||||
format: Option<String>,
|
||||
from: Option<chrono::DateTime<chrono::Utc>>,
|
||||
to: Option<chrono::DateTime<chrono::Utc>>,
|
||||
group_id: Option<uuid::Uuid>,
|
||||
}
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/compliance", get(compliance_report))
|
||||
.route("/patch-history", get(patch_history_report))
|
||||
.route("/vulnerability", get(vulnerability_report))
|
||||
.route("/audit", get(audit_report))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_report(
|
||||
db: sqlx::PgPool,
|
||||
params: ReportParams,
|
||||
use_pdf: bool,
|
||||
csv_name: &'static str,
|
||||
pdf_name: &'static str,
|
||||
) -> Response {
|
||||
let (ct, disposition, result) = if use_pdf {
|
||||
let disp = format!("attachment; filename=\"{}\"", pdf_name);
|
||||
let data = pm_reports::generate_pdf(&db, ¶ms).await;
|
||||
("application/pdf", disp, data)
|
||||
} else {
|
||||
let disp = format!("attachment; filename=\"{}\"", csv_name);
|
||||
let data = pm_reports::generate_csv(&db, ¶ms).await;
|
||||
("text/csv; charset=utf-8", disp, data)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(bytes) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::CONTENT_TYPE, HeaderValue::from_static(ct));
|
||||
headers.insert(
|
||||
header::CONTENT_DISPOSITION,
|
||||
HeaderValue::from_str(&disposition)
|
||||
.unwrap_or_else(|_| HeaderValue::from_static("attachment")),
|
||||
);
|
||||
(headers, Bytes::from(bytes)).into_response()
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "report generation failed");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Report error: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn compliance_report(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<ReportQuery>,
|
||||
) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::Compliance,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"compliance-report.csv",
|
||||
"compliance-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn patch_history_report(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<ReportQuery>,
|
||||
) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::PatchHistory,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"patch-history-report.csv",
|
||||
"patch-history-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn vulnerability_report(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<ReportQuery>,
|
||||
) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::Vulnerability,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"vulnerability-report.csv",
|
||||
"vulnerability-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn audit_report(State(state): State<AppState>, Query(q): Query<ReportQuery>) -> Response {
|
||||
let params = ReportParams {
|
||||
report_type: ReportType::Audit,
|
||||
from: q.from,
|
||||
to: q.to,
|
||||
group_id: q.group_id,
|
||||
};
|
||||
let use_pdf = matches!(q.format.as_deref(), Some("pdf"));
|
||||
run_report(
|
||||
state.db,
|
||||
params,
|
||||
use_pdf,
|
||||
"audit-report.csv",
|
||||
"audit-report.pdf",
|
||||
)
|
||||
.await
|
||||
}
|
||||
977
crates/pm-web/src/routes/settings.rs
Executable file
977
crates/pm-web/src/routes/settings.rs
Executable file
@ -0,0 +1,977 @@
|
||||
//! Settings management routes.
|
||||
//!
|
||||
//! GET /api/v1/settings — get all settings (admin only)
|
||||
//! PUT /api/v1/settings — update settings (admin only)
|
||||
//! POST /api/v1/settings/sso/discover — discover OIDC endpoints (admin only)
|
||||
//! POST /api/v1/settings/sso/test — test OIDC provider connectivity (admin only)
|
||||
//! POST /api/v1/settings/azure-sso/test — backward-compat alias for SSO test (admin only)
|
||||
//! POST /api/v1/settings/smtp/test — send test email (admin only)
|
||||
//! GET /api/v1/settings/ip-whitelist — get IP whitelist (admin only)
|
||||
//! PUT /api/v1/settings/ip-whitelist — update IP whitelist (admin only)
|
||||
//! POST /api/v1/settings/audit-integrity — verify audit log integrity (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use lettre::{
|
||||
message::{header::ContentType, Mailbox},
|
||||
transport::smtp::authentication::Credentials,
|
||||
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::audit::{log_event, verify_integrity, AuditAction};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ============================================================
|
||||
// Data structures
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SettingsResponse {
|
||||
pub oidc: OidcConfigResponse,
|
||||
pub smtp: SmtpConfig,
|
||||
pub polling: PollingConfig,
|
||||
pub ip_whitelist: Vec<String>,
|
||||
pub web_tls_strategy: String,
|
||||
pub notification: NotificationConfig,
|
||||
pub sso_callback_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OidcConfigResponse {
|
||||
pub enabled: bool,
|
||||
pub provider_type: String, // "keycloak", "azure", "custom"
|
||||
pub display_name: String,
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String, // Always masked in responses
|
||||
pub redirect_uri: String,
|
||||
pub scopes: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SmtpConfig {
|
||||
pub enabled: bool,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub username: String,
|
||||
pub from: String,
|
||||
pub tls_mode: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PollingConfig {
|
||||
pub health_poll_interval_secs: u64,
|
||||
pub patch_poll_interval_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateSettingsRequest {
|
||||
pub oidc: Option<OidcConfigUpdate>,
|
||||
pub smtp: Option<SmtpConfigUpdate>,
|
||||
pub polling: Option<PollingConfigUpdate>,
|
||||
pub ip_whitelist: Option<Vec<String>>,
|
||||
pub web_tls_strategy: Option<String>,
|
||||
pub notification: Option<NotificationConfigUpdate>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct NotificationConfig {
|
||||
pub email_enabled: bool,
|
||||
pub email_from: String,
|
||||
pub recipients: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct NotificationConfigUpdate {
|
||||
pub email_enabled: Option<bool>,
|
||||
pub email_from: Option<String>,
|
||||
pub recipients: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub provider_type: Option<String>,
|
||||
pub display_name: Option<String>,
|
||||
pub discovery_url: Option<String>,
|
||||
pub client_id: Option<String>,
|
||||
pub client_secret: Option<String>,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub scopes: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OidcDiscoveryRequest {
|
||||
pub discovery_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct OidcDiscoveryResult {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmtpConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub host: Option<String>,
|
||||
pub port: Option<u16>,
|
||||
pub username: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub from: Option<String>,
|
||||
pub tls_mode: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PollingConfigUpdate {
|
||||
pub health_poll_interval_secs: Option<u64>,
|
||||
pub patch_poll_interval_secs: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IpWhitelistUpdate {
|
||||
pub entries: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Router
|
||||
// ============================================================
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(get_settings).put(update_settings))
|
||||
.route("/sso/discover", post(discover_oidc))
|
||||
.route("/sso/test", post(test_oidc))
|
||||
.route("/azure-sso/test", post(test_azure_sso_compat))
|
||||
.route("/smtp/test", post(test_smtp))
|
||||
.route(
|
||||
"/ip-whitelist",
|
||||
get(get_ip_whitelist).put(update_ip_whitelist),
|
||||
)
|
||||
.route("/audit-integrity", post(audit_integrity))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Helpers
|
||||
// ============================================================
|
||||
|
||||
const MASKED: &str = "********";
|
||||
|
||||
fn write_access_required(auth: &AuthUser) -> Result<(), (StatusCode, Json<Value>)> {
|
||||
if !auth.role.can_write() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Write access required" } })),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_system_config(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Result<HashMap<String, String>, (StatusCode, Json<Value>)> {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as("SELECT key, value FROM system_config")
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load system_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(rows.into_iter().collect())
|
||||
}
|
||||
|
||||
fn build_settings_response(
|
||||
cfg: &HashMap<String, String>,
|
||||
oidc: OidcConfigResponse,
|
||||
) -> SettingsResponse {
|
||||
let get = |key: &str| -> String { cfg.get(key).cloned().unwrap_or_default() };
|
||||
|
||||
let recipients: Vec<String> =
|
||||
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
|
||||
|
||||
SettingsResponse {
|
||||
oidc,
|
||||
smtp: SmtpConfig {
|
||||
enabled: get("smtp_enabled") == "true",
|
||||
host: get("smtp_host"),
|
||||
port: get("smtp_port").parse().unwrap_or(587),
|
||||
username: get("smtp_username"),
|
||||
from: get("smtp_from"),
|
||||
tls_mode: get("smtp_tls_mode"),
|
||||
},
|
||||
polling: PollingConfig {
|
||||
health_poll_interval_secs: get("health_poll_interval_secs").parse().unwrap_or(300),
|
||||
patch_poll_interval_secs: get("patch_poll_interval_secs").parse().unwrap_or(1800),
|
||||
},
|
||||
ip_whitelist: serde_json::from_str(&get("ip_whitelist")).unwrap_or_default(),
|
||||
web_tls_strategy: get("web_tls_strategy"),
|
||||
notification: NotificationConfig {
|
||||
email_enabled: get("notification_email_enabled") == "true",
|
||||
email_from: get("notification_email_from"),
|
||||
recipients,
|
||||
},
|
||||
sso_callback_url: get("sso_callback_url"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_config_key(
|
||||
pool: &sqlx::PgPool,
|
||||
key: &str,
|
||||
value: &str,
|
||||
) -> Result<(), (StatusCode, Json<Value>)> {
|
||||
sqlx::query("UPDATE system_config SET value = $1, updated_at = NOW() WHERE key = $2")
|
||||
.bind(value)
|
||||
.bind(key)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, key, "Failed to update system_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_oidc_config(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Result<OidcConfigResponse, (StatusCode, Json<Value>)> {
|
||||
let row: Option<(bool, String, String, String, String, String, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(match row {
|
||||
Some((
|
||||
enabled,
|
||||
provider_type,
|
||||
display_name,
|
||||
discovery_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
scopes,
|
||||
)) => OidcConfigResponse {
|
||||
enabled,
|
||||
provider_type,
|
||||
display_name,
|
||||
discovery_url,
|
||||
client_id,
|
||||
client_secret: if client_secret.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
MASKED.to_string()
|
||||
},
|
||||
redirect_uri,
|
||||
scopes,
|
||||
},
|
||||
None => OidcConfigResponse {
|
||||
enabled: false,
|
||||
provider_type: "azure".to_string(),
|
||||
display_name: "Azure AD".to_string(),
|
||||
discovery_url: String::new(),
|
||||
client_id: String::new(),
|
||||
client_secret: String::new(),
|
||||
redirect_uri: String::new(),
|
||||
scopes: "openid profile email".to_string(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/settings
|
||||
// ============================================================
|
||||
|
||||
async fn get_settings(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
let cfg = load_system_config(&state.db).await?;
|
||||
// Inject read-only config values from TOML file (not stored in DB)
|
||||
let mut cfg = cfg;
|
||||
cfg.insert(
|
||||
"sso_callback_url".to_string(),
|
||||
state.config.security.sso_callback_url.clone(),
|
||||
);
|
||||
let oidc = fetch_oidc_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, oidc)))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/settings
|
||||
// ============================================================
|
||||
|
||||
async fn update_settings(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<UpdateSettingsRequest>,
|
||||
) -> Result<Json<SettingsResponse>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
// Update OIDC config
|
||||
if let Some(oidc) = req.oidc {
|
||||
let update_secret = oidc
|
||||
.client_secret
|
||||
.as_ref()
|
||||
.is_some_and(|s| s != MASKED && !s.is_empty());
|
||||
|
||||
let result = if update_secret {
|
||||
sqlx::query(
|
||||
"UPDATE oidc_config SET \
|
||||
enabled = COALESCE($1, enabled), \
|
||||
provider_type = COALESCE($2, provider_type), \
|
||||
display_name = COALESCE($3, display_name), \
|
||||
discovery_url = COALESCE($4, discovery_url), \
|
||||
client_id = COALESCE($5, client_id), \
|
||||
client_secret = $6, \
|
||||
redirect_uri = COALESCE($7, redirect_uri), \
|
||||
scopes = COALESCE($8, scopes), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = 1",
|
||||
)
|
||||
.bind(oidc.enabled)
|
||||
.bind(&oidc.provider_type)
|
||||
.bind(&oidc.display_name)
|
||||
.bind(&oidc.discovery_url)
|
||||
.bind(&oidc.client_id)
|
||||
.bind(oidc.client_secret.as_deref().unwrap_or(""))
|
||||
.bind(&oidc.redirect_uri)
|
||||
.bind(&oidc.scopes)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE oidc_config SET \
|
||||
enabled = COALESCE($1, enabled), \
|
||||
provider_type = COALESCE($2, provider_type), \
|
||||
display_name = COALESCE($3, display_name), \
|
||||
discovery_url = COALESCE($4, discovery_url), \
|
||||
client_id = COALESCE($5, client_id), \
|
||||
redirect_uri = COALESCE($6, redirect_uri), \
|
||||
scopes = COALESCE($7, scopes), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = 1",
|
||||
)
|
||||
.bind(oidc.enabled)
|
||||
.bind(&oidc.provider_type)
|
||||
.bind(&oidc.display_name)
|
||||
.bind(&oidc.discovery_url)
|
||||
.bind(&oidc.client_id)
|
||||
.bind(&oidc.redirect_uri)
|
||||
.bind(&oidc.scopes)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
result.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": format!("Failed to update OIDC config: {}", e) } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("oidc"),
|
||||
Some("1"),
|
||||
json!({ "section": "oidc" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update SMTP config
|
||||
if let Some(smtp) = &req.smtp {
|
||||
if let Some(v) = smtp.enabled {
|
||||
update_config_key(&state.db, "smtp_enabled", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.host {
|
||||
update_config_key(&state.db, "smtp_host", v).await?;
|
||||
}
|
||||
if let Some(v) = smtp.port {
|
||||
update_config_key(&state.db, "smtp_port", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.username {
|
||||
update_config_key(&state.db, "smtp_username", v).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.password {
|
||||
if v != MASKED {
|
||||
update_config_key(&state.db, "smtp_password", v).await?;
|
||||
}
|
||||
}
|
||||
if let Some(ref v) = smtp.from {
|
||||
update_config_key(&state.db, "smtp_from", v).await?;
|
||||
}
|
||||
if let Some(ref v) = smtp.tls_mode {
|
||||
update_config_key(&state.db, "smtp_tls_mode", v).await?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("smtp"),
|
||||
Some("system_config"),
|
||||
json!({ "section": "smtp" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update polling config
|
||||
if let Some(polling) = &req.polling {
|
||||
if let Some(v) = polling.health_poll_interval_secs {
|
||||
update_config_key(&state.db, "health_poll_interval_secs", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(v) = polling.patch_poll_interval_secs {
|
||||
update_config_key(&state.db, "patch_poll_interval_secs", &v.to_string()).await?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("polling"),
|
||||
Some("system_config"),
|
||||
json!({ "section": "polling" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update IP whitelist
|
||||
if let Some(ref entries) = req.ip_whitelist {
|
||||
let json_str = serde_json::to_string(entries).unwrap_or_else(|_| "[]".to_string());
|
||||
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
|
||||
|
||||
// Update in-memory AuthConfig for immediate enforcement
|
||||
state.auth_config.update_ip_whitelist(entries.clone());
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("ip_whitelist"),
|
||||
Some("system_config"),
|
||||
json!({ "entries": entries }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update web TLS strategy
|
||||
if let Some(ref v) = req.web_tls_strategy {
|
||||
update_config_key(&state.db, "web_tls_strategy", v).await?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("web_tls_strategy"),
|
||||
Some("system_config"),
|
||||
json!({ "web_tls_strategy": v }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update notification config
|
||||
if let Some(notif) = &req.notification {
|
||||
if let Some(v) = notif.email_enabled {
|
||||
update_config_key(&state.db, "notification_email_enabled", &v.to_string()).await?;
|
||||
}
|
||||
if let Some(ref v) = notif.email_from {
|
||||
update_config_key(&state.db, "notification_email_from", v).await?;
|
||||
}
|
||||
if let Some(ref v) = notif.recipients {
|
||||
let json_str = serde_json::to_string(v).unwrap_or_else(|_| "[]".to_string());
|
||||
update_config_key(&state.db, "notification_email_recipients", &json_str).await?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("notification"),
|
||||
Some("system_config"),
|
||||
json!({ "section": "notification" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Return updated settings
|
||||
let cfg = load_system_config(&state.db).await?;
|
||||
// Inject read-only config values from TOML file (not stored in DB)
|
||||
let mut cfg = cfg;
|
||||
cfg.insert(
|
||||
"sso_callback_url".to_string(),
|
||||
state.config.security.sso_callback_url.clone(),
|
||||
);
|
||||
let oidc = fetch_oidc_config(&state.db).await?;
|
||||
Ok(Json(build_settings_response(&cfg, oidc)))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/sso/discover
|
||||
// ============================================================
|
||||
|
||||
async fn discover_oidc(
|
||||
State(_state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<OidcDiscoveryRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
if req.discovery_url.is_empty() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": "discovery_url is required" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
match client.get(&req.discovery_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: Value = resp.json().await.unwrap_or(json!({}));
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"issuer": body.get("issuer").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"authorization_endpoint": body.get("authorization_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"token_endpoint": body.get("token_endpoint").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"jwks_uri": body.get("jwks_uri").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"userinfo_endpoint": body.get("userinfo_endpoint").and_then(|v| v.as_str()),
|
||||
})))
|
||||
},
|
||||
Ok(resp) => Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({ "error": { "code": "discovery_failed", "message": format!("Discovery endpoint returned HTTP {}", resp.status()) } }),
|
||||
),
|
||||
)),
|
||||
Err(e) => Err((
|
||||
StatusCode::BAD_GATEWAY,
|
||||
Json(
|
||||
json!({ "error": { "code": "discovery_failed", "message": format!("Failed to reach discovery endpoint: {}", e) } }),
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/sso/test
|
||||
// ============================================================
|
||||
|
||||
async fn test_oidc(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let row: Option<(bool, String, String)> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, discovery_url FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (enabled, provider_type, discovery_url) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC is not configured"
|
||||
})));
|
||||
},
|
||||
};
|
||||
|
||||
if !enabled {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC is not enabled"
|
||||
})));
|
||||
}
|
||||
|
||||
if discovery_url.is_empty() {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "OIDC discovery URL is not set"
|
||||
})));
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "HTTP client error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
match client.get(&discovery_url).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: Value = resp.json().await.unwrap_or(json!({}));
|
||||
let issuer = body.get("issuer").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let provider_label = match provider_type.as_str() {
|
||||
"keycloak" => "Keycloak",
|
||||
"azure" => "Azure AD",
|
||||
_ => "OIDC",
|
||||
};
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": format!("{} provider verified successfully", provider_label),
|
||||
"issuer": issuer,
|
||||
"provider_type": provider_type,
|
||||
})))
|
||||
},
|
||||
Ok(resp) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to reach OIDC provider: HTTP {}", resp.status())
|
||||
}))),
|
||||
Err(e) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to reach OIDC provider: {}", e)
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/azure-sso/test (backward-compatible alias)
|
||||
// ============================================================
|
||||
|
||||
async fn test_azure_sso_compat(
|
||||
state: State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
test_oidc(state, auth).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/smtp/test
|
||||
// ============================================================
|
||||
|
||||
async fn test_smtp(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let cfg = load_system_config(&state.db).await?;
|
||||
|
||||
let smtp_enabled = cfg.get("smtp_enabled").map(|v| v.as_str()) == Some("true");
|
||||
if !smtp_enabled {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "SMTP is not enabled"
|
||||
})));
|
||||
}
|
||||
|
||||
let host = cfg.get("smtp_host").cloned().unwrap_or_default();
|
||||
let port: u16 = cfg
|
||||
.get("smtp_port")
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(587);
|
||||
let username = cfg.get("smtp_username").cloned().unwrap_or_default();
|
||||
let password = cfg.get("smtp_password").cloned().unwrap_or_default();
|
||||
let from_addr = cfg.get("smtp_from").cloned().unwrap_or_default();
|
||||
let tls_mode = cfg
|
||||
.get("smtp_tls_mode")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "starttls".to_string());
|
||||
|
||||
let recipients_str = cfg
|
||||
.get("notification_email_recipients")
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let recipients: Vec<String> = serde_json::from_str(&recipients_str).unwrap_or_default();
|
||||
|
||||
if host.is_empty() || from_addr.is_empty() {
|
||||
return Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": "SMTP host or from address is not configured"
|
||||
})));
|
||||
}
|
||||
|
||||
let result = send_smtp_test(
|
||||
&host,
|
||||
port,
|
||||
&username,
|
||||
&password,
|
||||
&from_addr,
|
||||
&tls_mode,
|
||||
&recipients,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
let recipient_info = if recipients.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" and {} recipient(s)", recipients.len())
|
||||
};
|
||||
Ok(Json(json!({
|
||||
"success": true,
|
||||
"message": format!("Test email sent successfully to from address{}", recipient_info)
|
||||
})))
|
||||
},
|
||||
Err(e) => Ok(Json(json!({
|
||||
"success": false,
|
||||
"message": format!("Failed to send test email: {}", e)
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_smtp_test(
|
||||
host: &str,
|
||||
port: u16,
|
||||
username: &str,
|
||||
password: &str,
|
||||
from_addr: &str,
|
||||
tls_mode: &str,
|
||||
recipients: &[String],
|
||||
) -> Result<(), String> {
|
||||
let from_mailbox: Mailbox = from_addr
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid from address: {}", e))?;
|
||||
|
||||
let mut builder = Message::builder()
|
||||
.from(from_mailbox.clone())
|
||||
.to(from_mailbox);
|
||||
|
||||
for recipient in recipients {
|
||||
if let Ok(addr) = recipient.parse() {
|
||||
builder = builder.bcc(addr);
|
||||
}
|
||||
}
|
||||
|
||||
let body = if recipients.is_empty() {
|
||||
"This is a test email from Linux Patch Manager.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"This is a test email from Linux Patch Manager.\n\nSent to: {}",
|
||||
recipients.join(", ")
|
||||
)
|
||||
};
|
||||
|
||||
let email = builder
|
||||
.subject("Linux Patch Manager — SMTP Test")
|
||||
.header(ContentType::TEXT_PLAIN)
|
||||
.body(body)
|
||||
.map_err(|e| format!("Failed to build email: {}", e))?;
|
||||
|
||||
let result = match tls_mode {
|
||||
"tls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(host)
|
||||
.map_err(|e| format!("TLS relay error: {}", e))?;
|
||||
builder = builder.port(port);
|
||||
if !username.is_empty() {
|
||||
builder = builder
|
||||
.credentials(Credentials::new(username.to_string(), password.to_string()));
|
||||
}
|
||||
let transport = builder.build();
|
||||
transport.send(email).await
|
||||
},
|
||||
"starttls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(host)
|
||||
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
|
||||
builder = builder.port(port);
|
||||
if !username.is_empty() {
|
||||
builder = builder
|
||||
.credentials(Credentials::new(username.to_string(), password.to_string()));
|
||||
}
|
||||
let transport = builder.build();
|
||||
transport.send(email).await
|
||||
},
|
||||
_ => {
|
||||
// "none" — plaintext / no TLS
|
||||
let mut builder =
|
||||
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(host).port(port);
|
||||
if !username.is_empty() {
|
||||
builder = builder
|
||||
.credentials(Credentials::new(username.to_string(), password.to_string()));
|
||||
}
|
||||
let transport = builder.build();
|
||||
transport.send(email).await
|
||||
},
|
||||
};
|
||||
|
||||
result
|
||||
.map(|_| ())
|
||||
.map_err(|e| format!("SMTP send error: {}", e))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/settings/ip-whitelist
|
||||
// ============================================================
|
||||
|
||||
async fn get_ip_whitelist(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let value: Option<String> = sqlx::query_scalar(
|
||||
"SELECT value FROM system_config WHERE key = 'ip_whitelist'",
|
||||
)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load ip_whitelist");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let entries: Vec<String> = serde_json::from_str(&value.unwrap_or_default()).unwrap_or_default();
|
||||
Ok(Json(json!({ "entries": entries })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/settings/ip-whitelist
|
||||
// ============================================================
|
||||
|
||||
async fn update_ip_whitelist(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<IpWhitelistUpdate>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
// Validate each entry
|
||||
for entry in &req.entries {
|
||||
if entry.parse::<ipnet::IpNet>().is_err() && entry.parse::<std::net::IpAddr>().is_err() {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": format!("Invalid CIDR or IP: {}", entry) } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let json_str = serde_json::to_string(&req.entries).unwrap_or_else(|_| "[]".to_string());
|
||||
update_config_key(&state.db, "ip_whitelist", &json_str).await?;
|
||||
|
||||
// Update in-memory AuthConfig for immediate enforcement
|
||||
state.auth_config.update_ip_whitelist(req.entries.clone());
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::ConfigChanged,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("ip_whitelist"),
|
||||
Some("system_config"),
|
||||
json!({ "entries": req.entries }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
Ok(Json(json!({ "entries": req.entries })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/settings/audit-integrity
|
||||
// ============================================================
|
||||
|
||||
/// Verify audit log hash chain integrity.
|
||||
/// Returns whether the chain is intact, rows checked, and any errors.
|
||||
async fn audit_integrity(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
write_access_required(&auth)?;
|
||||
|
||||
let result = verify_integrity(&state.db).await;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::AuditIntegrityVerified,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("audit_log"),
|
||||
None,
|
||||
json!({
|
||||
"intact": result.intact,
|
||||
"rows_checked": result.rows_checked,
|
||||
"error_count": result.errors.len(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({
|
||||
"intact": result.intact,
|
||||
"rows_checked": result.rows_checked,
|
||||
"errors": result.errors.iter().map(|e| json!({
|
||||
"row_id": e.row_id,
|
||||
"expected_hash": e.expected_hash,
|
||||
"actual_hash": e.actual_hash,
|
||||
})).collect::<Vec<_>>(),
|
||||
})))
|
||||
}
|
||||
838
crates/pm-web/src/routes/sso.rs
Executable file
838
crates/pm-web/src/routes/sso.rs
Executable file
@ -0,0 +1,838 @@
|
||||
//! Generic OIDC SSO routes (Keycloak, Azure AD, Custom).
|
||||
//!
|
||||
//! Public routes (no auth required):
|
||||
//! GET /api/v1/auth/sso/login — redirect to OIDC provider authorization URL
|
||||
//! GET /api/v1/auth/sso/callback — handle OIDC provider callback, redirect to frontend SPA
|
||||
//!
|
||||
//! Backward-compatible aliases:
|
||||
//! GET /api/v1/auth/azure/login → redirects to generic SSO login
|
||||
//! GET /api/v1/auth/azure/callback → redirects to generic SSO callback
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json, Redirect},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
|
||||
use pm_auth::{jwt::issue_access_token, refresh};
|
||||
use pm_core::audit::{log_event, AuditAction};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ============================================================
|
||||
// Data structures
|
||||
// ============================================================
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SsoSession {
|
||||
pub code_verifier: String,
|
||||
pub created_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
#[allow(dead_code)]
|
||||
access_token: Option<String>,
|
||||
id_token: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
expires_in: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IdTokenClaims {
|
||||
email: Option<String>,
|
||||
name: Option<String>,
|
||||
sub: Option<String>,
|
||||
oid: Option<String>,
|
||||
preferred_username: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct DbUserForSso {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
display_name: String,
|
||||
role: String,
|
||||
is_active: bool,
|
||||
mfa_enabled: bool,
|
||||
}
|
||||
|
||||
/// OIDC provider configuration from database.
|
||||
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||
pub struct OidcConfig {
|
||||
pub enabled: bool,
|
||||
pub provider_type: String,
|
||||
pub display_name: String,
|
||||
pub discovery_url: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub redirect_uri: String,
|
||||
pub scopes: String,
|
||||
}
|
||||
|
||||
/// Cached OIDC discovery document.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OidcDiscovery {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>,
|
||||
pub fetched_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Cache for OIDC discovery documents and JWKS with TTL-based refresh.
|
||||
#[derive(Default)]
|
||||
pub struct OidcCache {
|
||||
pub discovery: Option<OidcDiscovery>,
|
||||
pub jwks: Option<serde_json::Value>,
|
||||
pub jwks_fetched_at: Option<chrono::DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// JWKS cache TTL in seconds (1 hour).
|
||||
const JWKS_CACHE_TTL_SECS: i64 = 3600;
|
||||
/// Discovery cache TTL in seconds (1 hour).
|
||||
const DISCOVERY_CACHE_TTL_SECS: i64 = 3600;
|
||||
|
||||
// ============================================================
|
||||
// Router
|
||||
// ============================================================
|
||||
|
||||
pub fn public_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", get(sso_login))
|
||||
.route("/callback", get(sso_callback))
|
||||
.route("/config", get(sso_config))
|
||||
}
|
||||
|
||||
/// Backward-compatible Azure SSO routes — redirect to generic SSO endpoints.
|
||||
pub fn azure_compat_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/login", get(azure_login_redirect))
|
||||
.route("/callback", get(azure_callback_redirect))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/sso/config
|
||||
// ============================================================
|
||||
|
||||
/// Public endpoint returning minimal SSO configuration for the login page.
|
||||
/// Returns only: enabled, display_name, auth_url — no secrets exposed.
|
||||
async fn sso_config(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
// If we can't load config, SSO is effectively disabled
|
||||
return Ok(Json(json!({
|
||||
"enabled": false,
|
||||
"display_name": "SSO",
|
||||
"auth_url": ""
|
||||
})));
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Json(json!({
|
||||
"enabled": config.enabled,
|
||||
"display_name": if config.display_name.is_empty() { "SSO".to_string() } else { config.display_name },
|
||||
"auth_url": "/api/v1/auth/sso/login"
|
||||
})))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/sso/login
|
||||
// ============================================================
|
||||
|
||||
async fn sso_login(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
let config = load_oidc_config(&state.db).await?;
|
||||
|
||||
if !config.enabled {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "SSO is not enabled" } })),
|
||||
));
|
||||
}
|
||||
|
||||
if config.discovery_url.is_empty() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "SSO discovery URL is not configured" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch OIDC discovery document (with caching)
|
||||
let discovery = match fetch_discovery(&state).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
json!({ "error": { "code": "internal_error", "message": format!("Failed to fetch OIDC discovery: {}", e) } }),
|
||||
),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Generate PKCE code_verifier (32 random bytes → base64url)
|
||||
let mut verifier_bytes = [0u8; 32];
|
||||
rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut verifier_bytes);
|
||||
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
|
||||
|
||||
// code_challenge = BASE64URL(SHA256(code_verifier))
|
||||
let challenge_digest = Sha256::digest(code_verifier.as_bytes());
|
||||
let code_challenge = URL_SAFE_NO_PAD.encode(challenge_digest);
|
||||
|
||||
// Generate state token
|
||||
let state_token = Uuid::new_v4().to_string();
|
||||
|
||||
// Store (state_token, code_verifier) in sso_sessions DashMap
|
||||
state.sso_sessions.insert(
|
||||
state_token.clone(),
|
||||
SsoSession {
|
||||
code_verifier,
|
||||
created_at: Utc::now(),
|
||||
},
|
||||
);
|
||||
|
||||
// Build authorization URL from discovery
|
||||
let encoded_scopes = urlencoding::encode(&config.scopes);
|
||||
let auth_url = format!(
|
||||
"{}?client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}",
|
||||
discovery.authorization_endpoint,
|
||||
urlencoding::encode(&config.client_id),
|
||||
urlencoding::encode(&config.redirect_uri),
|
||||
encoded_scopes,
|
||||
code_challenge,
|
||||
state_token
|
||||
);
|
||||
|
||||
Ok(Redirect::to(&auth_url))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/auth/sso/callback
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: Option<String>,
|
||||
state: Option<String>,
|
||||
error: Option<String>,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
async fn sso_callback(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<Redirect, Redirect> {
|
||||
let callback_url = &state.config.security.sso_callback_url;
|
||||
|
||||
let error_redirect = |code: &str, message: &str| -> Redirect {
|
||||
let url = format!(
|
||||
"{}?error={}&error_description={}",
|
||||
callback_url,
|
||||
urlencoding::encode(code),
|
||||
urlencoding::encode(message)
|
||||
);
|
||||
Redirect::to(&url)
|
||||
};
|
||||
|
||||
if let Some(error) = params.error {
|
||||
let desc = params.error_description.unwrap_or_default();
|
||||
let message = format!("OIDC provider error: {} - {}", error, desc);
|
||||
return Err(error_redirect("sso_error", &message));
|
||||
}
|
||||
|
||||
let code = match params.code {
|
||||
Some(c) => c,
|
||||
None => return Err(error_redirect("bad_request", "Missing authorization code")),
|
||||
};
|
||||
|
||||
let state_token = match params.state {
|
||||
Some(s) => s,
|
||||
None => return Err(error_redirect("bad_request", "Missing state parameter")),
|
||||
};
|
||||
|
||||
let sso_session = match state.sso_sessions.remove(&state_token).map(|(_, v)| v) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Err(error_redirect(
|
||||
"bad_request",
|
||||
"Invalid or expired state token",
|
||||
))
|
||||
},
|
||||
};
|
||||
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to load OIDC config",
|
||||
))
|
||||
},
|
||||
};
|
||||
|
||||
let discovery = match fetch_discovery(&state).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to fetch OIDC discovery");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to fetch OIDC discovery",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
// Exchange code for tokens
|
||||
let client = match reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to build HTTP client");
|
||||
return Err(error_redirect("internal_error", "HTTP client error"));
|
||||
},
|
||||
};
|
||||
|
||||
let mut params_vec: Vec<(&str, String)> = vec![
|
||||
("grant_type", "authorization_code".to_string()),
|
||||
("code", code.clone()),
|
||||
("redirect_uri", config.redirect_uri.clone()),
|
||||
("client_id", config.client_id.clone()),
|
||||
("code_verifier", sso_session.code_verifier.clone()),
|
||||
];
|
||||
|
||||
// For confidential clients (Azure AD), include client_secret
|
||||
if !config.client_secret.is_empty() {
|
||||
params_vec.push(("client_secret", config.client_secret.clone()));
|
||||
}
|
||||
|
||||
let token_resp = match client
|
||||
.post(&discovery.token_endpoint)
|
||||
.form(¶ms_vec)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Token exchange request failed");
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
&format!("Token exchange failed: {}", e),
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
if !token_resp.status().is_success() {
|
||||
let status = token_resp.status();
|
||||
let body = token_resp.text().await.unwrap_or_default();
|
||||
tracing::error!(status = %status, body = %body, "Token exchange failed");
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
&format!("Token exchange failed: HTTP {}", status),
|
||||
));
|
||||
}
|
||||
|
||||
let token_data: TokenResponse = match token_resp.json().await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to parse token response");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to parse token response",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
let id_token = match token_data.id_token {
|
||||
Some(t) => t,
|
||||
None => return Err(error_redirect("sso_error", "No id_token in response")),
|
||||
};
|
||||
|
||||
let claims = match verify_id_token(&id_token, &config, &discovery, &state.oidc_cache).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to verify id_token");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Failed to verify id_token",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
let email = claims.email.unwrap_or_default();
|
||||
let name = claims.name.unwrap_or_default();
|
||||
let oidc_sub = claims.sub.unwrap_or_default();
|
||||
let azure_oid = claims.oid.unwrap_or_default();
|
||||
let preferred_username = claims.preferred_username.unwrap_or_else(|| email.clone());
|
||||
|
||||
let provider_subject = if !oidc_sub.is_empty() {
|
||||
oidc_sub.clone()
|
||||
} else if !azure_oid.is_empty() {
|
||||
azure_oid.clone()
|
||||
} else {
|
||||
return Err(error_redirect(
|
||||
"sso_error",
|
||||
"Missing subject identifier in id_token",
|
||||
));
|
||||
};
|
||||
|
||||
if email.is_empty() {
|
||||
return Err(error_redirect("sso_error", "Missing email in id_token"));
|
||||
}
|
||||
|
||||
let auth_provider = match config.provider_type.as_str() {
|
||||
"keycloak" => "keycloak",
|
||||
"azure" => "azure_sso",
|
||||
_ => "oidc",
|
||||
};
|
||||
|
||||
// First try exact match: email AND auth_provider
|
||||
let user_opt: Option<DbUserForSso> = match sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
|
||||
FROM users WHERE email = $1 AND auth_provider = $2::auth_provider"#,
|
||||
)
|
||||
.bind(&email)
|
||||
.bind(auth_provider)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(o) => o,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to look up SSO user");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
},
|
||||
};
|
||||
|
||||
let user = match user_opt {
|
||||
Some(u) if !u.is_active => {
|
||||
return Err(error_redirect("account_disabled", "Account is disabled"));
|
||||
},
|
||||
Some(u) => u,
|
||||
None => {
|
||||
// Try to find existing user by email alone (may have different auth_provider)
|
||||
let existing_user: Option<DbUserForSso> = match sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, role::text as role, is_active, mfa_enabled
|
||||
FROM users WHERE email = $1"#,
|
||||
)
|
||||
.bind(&email)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(o) => o,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to look up existing user by email");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
},
|
||||
};
|
||||
|
||||
match existing_user {
|
||||
Some(existing) if !existing.is_active => {
|
||||
return Err(error_redirect("account_disabled", "Account is disabled"));
|
||||
},
|
||||
Some(existing) => {
|
||||
// Link existing local user to SSO provider
|
||||
tracing::info!(user_id = %existing.id, "Linking existing user to SSO provider");
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE users SET auth_provider = $1::auth_provider, azure_oid = COALESCE(azure_oid, $2), oidc_sub = COALESCE(oidc_sub, $3) WHERE id = $4",
|
||||
)
|
||||
.bind(auth_provider)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.bind(existing.id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, "Failed to link user to SSO provider");
|
||||
return Err(error_redirect("internal_error", "Failed to link SSO account"));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
None,
|
||||
Some(auth_provider),
|
||||
Some("user"),
|
||||
Some(&existing.id.to_string()),
|
||||
json!({ "action": "sso_link", "auth_provider": auth_provider, "email": email }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
DbUserForSso {
|
||||
id: existing.id,
|
||||
username: existing.username.clone(),
|
||||
display_name: if name.is_empty() {
|
||||
existing.display_name.clone()
|
||||
} else {
|
||||
name
|
||||
},
|
||||
role: existing.role.clone(),
|
||||
is_active: existing.is_active,
|
||||
mfa_enabled: existing.mfa_enabled,
|
||||
}
|
||||
},
|
||||
None => {
|
||||
// No existing user - create new one
|
||||
let id: Uuid = match sqlx::query_scalar(
|
||||
r#"INSERT INTO users (username, display_name, email, role, auth_provider, azure_oid, oidc_sub)
|
||||
VALUES ($1, $2, $3, 'reporter'::user_role, $4::auth_provider, $5, $6)
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(&preferred_username)
|
||||
.bind(&name)
|
||||
.bind(&email)
|
||||
.bind(auth_provider)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to create SSO user");
|
||||
return Err(error_redirect("internal_error", "Failed to create user"));
|
||||
},
|
||||
};
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
None,
|
||||
Some(auth_provider),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "auth_provider": auth_provider, "email": email }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
DbUserForSso {
|
||||
id,
|
||||
username: preferred_username,
|
||||
display_name: name,
|
||||
role: "reporter".to_string(),
|
||||
is_active: true,
|
||||
mfa_enabled: false,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Update last_login_at and provider subject IDs
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE users SET last_login_at = NOW(), azure_oid = COALESCE(azure_oid, $1), oidc_sub = COALESCE(oidc_sub, $2) WHERE id = $3",
|
||||
)
|
||||
.bind(if azure_oid.is_empty() { None } else { Some(azure_oid.as_str()) })
|
||||
.bind(if provider_subject.is_empty() { None } else { Some(provider_subject.as_str()) })
|
||||
.bind(user.id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, "Failed to update last_login_at");
|
||||
return Err(error_redirect("internal_error", "Database error"));
|
||||
}
|
||||
|
||||
let access_ttl = state.config.security.jwt_access_ttl_secs as i64;
|
||||
let access_token = match issue_access_token(
|
||||
user.id,
|
||||
&user.username,
|
||||
&user.role,
|
||||
access_ttl,
|
||||
&state.signing_key_pem,
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to issue access token");
|
||||
return Err(error_redirect("internal_error", "Token issuance failed"));
|
||||
},
|
||||
};
|
||||
|
||||
let raw_refresh = match refresh::issue(&state.db, user.id, None, None).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to issue refresh token");
|
||||
return Err(error_redirect(
|
||||
"internal_error",
|
||||
"Refresh token issuance failed",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserLogin,
|
||||
Some(user.id),
|
||||
Some(&user.username),
|
||||
None,
|
||||
None,
|
||||
json!({ "auth_provider": auth_provider }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let user_json = json!({
|
||||
"id": user.id.to_string(),
|
||||
"username": user.username,
|
||||
"display_name": user.display_name,
|
||||
"role": user.role,
|
||||
"auth_provider": auth_provider,
|
||||
"mfa_enabled": user.mfa_enabled,
|
||||
});
|
||||
|
||||
let redirect_url = format!(
|
||||
"{}?access_token={}&refresh_token={}&token_type=Bearer&expires_in={}&user={}",
|
||||
callback_url,
|
||||
urlencoding::encode(&access_token),
|
||||
urlencoding::encode(&raw_refresh.0),
|
||||
access_ttl,
|
||||
urlencoding::encode(&user_json.to_string()),
|
||||
);
|
||||
|
||||
Ok(Redirect::to(&redirect_url))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Backward-compatible Azure SSO redirect handlers
|
||||
// ============================================================
|
||||
|
||||
async fn azure_login_redirect(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
sso_login(State(state)).await
|
||||
}
|
||||
|
||||
async fn azure_callback_redirect(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<Redirect, Redirect> {
|
||||
sso_callback(State(state), axum::extract::Query(params)).await
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Database helpers
|
||||
// ============================================================
|
||||
|
||||
async fn load_oidc_config(pool: &sqlx::PgPool) -> Result<OidcConfig, (StatusCode, Json<Value>)> {
|
||||
let row: Option<OidcConfig> = sqlx::query_as(
|
||||
"SELECT enabled, provider_type, display_name, discovery_url, client_id, client_secret, redirect_uri, scopes FROM oidc_config WHERE id = 1",
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to load oidc_config");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(row.unwrap_or(OidcConfig {
|
||||
enabled: false,
|
||||
provider_type: "azure".to_string(),
|
||||
display_name: "Azure AD".to_string(),
|
||||
discovery_url: String::new(),
|
||||
client_id: String::new(),
|
||||
client_secret: String::new(),
|
||||
redirect_uri: String::new(),
|
||||
scopes: "openid profile email".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OIDC Discovery & JWKS
|
||||
// ============================================================
|
||||
|
||||
async fn fetch_discovery(state: &AppState) -> Result<OidcDiscovery, String> {
|
||||
let config = match load_oidc_config(&state.db).await {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err("Failed to load OIDC config".to_string());
|
||||
},
|
||||
};
|
||||
let discovery_url = config.discovery_url;
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let cache = state.oidc_cache.lock().await;
|
||||
if let Some(ref disc) = cache.discovery {
|
||||
let elapsed = Utc::now().signed_duration_since(disc.fetched_at);
|
||||
if elapsed.num_seconds() < DISCOVERY_CACHE_TTL_SECS {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Discovery fetch failed: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!(
|
||||
"Discovery fetch failed: HTTP {} — {}",
|
||||
status, body
|
||||
));
|
||||
}
|
||||
|
||||
let doc: Value = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse discovery document: {}", e))?;
|
||||
|
||||
let discovery = OidcDiscovery {
|
||||
issuer: doc
|
||||
.get("issuer")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
authorization_endpoint: doc
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
token_endpoint: doc
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
jwks_uri: doc
|
||||
.get("jwks_uri")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
userinfo_endpoint: doc
|
||||
.get("userinfo_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string()),
|
||||
fetched_at: Utc::now(),
|
||||
};
|
||||
|
||||
{
|
||||
let mut cache = state.oidc_cache.lock().await;
|
||||
cache.discovery = Some(discovery.clone());
|
||||
}
|
||||
|
||||
Ok(discovery)
|
||||
}
|
||||
|
||||
async fn verify_id_token(
|
||||
token: &str,
|
||||
config: &OidcConfig,
|
||||
discovery: &OidcDiscovery,
|
||||
oidc_cache: &Arc<Mutex<OidcCache>>,
|
||||
) -> Result<IdTokenClaims, String> {
|
||||
let header = decode_header(token).map_err(|e| format!("Failed to decode JWT header: {}", e))?;
|
||||
let kid = header.kid.ok_or("JWT header missing 'kid' field")?;
|
||||
|
||||
let jwks = {
|
||||
let cache = oidc_cache.lock().await;
|
||||
let needs_fetch = match (&cache.jwks, &cache.jwks_fetched_at) {
|
||||
(None, _) => true,
|
||||
(Some(_), None) => true,
|
||||
(Some(_), Some(fetched)) => {
|
||||
let elapsed = Utc::now().signed_duration_since(*fetched);
|
||||
elapsed.num_seconds() > JWKS_CACHE_TTL_SECS
|
||||
},
|
||||
};
|
||||
|
||||
if needs_fetch {
|
||||
drop(cache);
|
||||
let jwks_value = fetch_jwks(&discovery.jwks_uri).await?;
|
||||
let mut cache = oidc_cache.lock().await;
|
||||
cache.jwks = Some(jwks_value);
|
||||
cache.jwks_fetched_at = Some(Utc::now());
|
||||
cache.jwks.clone().unwrap()
|
||||
} else {
|
||||
cache.jwks.clone().unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
let keys_array = jwks
|
||||
.get("keys")
|
||||
.ok_or("JWKS response missing 'keys' array")?
|
||||
.as_array()
|
||||
.ok_or("JWKS 'keys' is not an array")?;
|
||||
|
||||
let jwk = keys_array
|
||||
.iter()
|
||||
.find(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()))
|
||||
.ok_or_else(|| format!("No matching JWK found for kid: {}", kid))?;
|
||||
|
||||
let n = jwk
|
||||
.get("n")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or("JWK missing 'n' (modulus) field")?;
|
||||
let e = jwk
|
||||
.get("e")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or("JWK missing 'e' (exponent) field")?;
|
||||
|
||||
let decoding_key = DecodingKey::from_rsa_components(n, e)
|
||||
.map_err(|e| format!("Failed to construct RSA decoding key: {}", e))?;
|
||||
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.iss = Some(HashSet::from([discovery.issuer.clone()]));
|
||||
validation.aud = Some(HashSet::from([config.client_id.clone()]));
|
||||
validation.leeway = 60;
|
||||
|
||||
let token_data = decode::<IdTokenClaims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("JWT signature verification failed: {}", e))?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
async fn fetch_jwks(jwks_uri: &str) -> Result<serde_json::Value, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to build HTTP client for JWKS fetch: {}", e))?;
|
||||
|
||||
let resp = client
|
||||
.get(jwks_uri)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("JWKS fetch request failed: {}", e))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("JWKS fetch failed: HTTP {} — {}", status, body));
|
||||
}
|
||||
|
||||
resp.json::<serde_json::Value>()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse JWKS response: {}", e))
|
||||
}
|
||||
145
crates/pm-web/src/routes/status.rs
Executable file
145
crates/pm-web/src/routes/status.rs
Executable file
@ -0,0 +1,145 @@
|
||||
//! Fleet status routes.
|
||||
//!
|
||||
//! GET /api/v1/status/fleet — aggregate health and patch summary across all hosts.
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::Json, routing::get, Router};
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new().route("/fleet", get(fleet_status))
|
||||
}
|
||||
|
||||
// ── Response type ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FleetStatus {
|
||||
pub total_hosts: i64,
|
||||
pub healthy: i64,
|
||||
pub degraded: i64,
|
||||
pub unreachable: i64,
|
||||
pub pending: i64,
|
||||
pub total_pending_patches: i64,
|
||||
pub hosts_requiring_reboot: i64,
|
||||
pub compliance_pct: f64,
|
||||
}
|
||||
|
||||
// ── GET /api/v1/status/fleet ──────────────────────────────────────────────────
|
||||
|
||||
pub async fn fleet_status(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<FleetStatus>, (StatusCode, Json<Value>)> {
|
||||
// ── 1. Host health aggregates ─────────────────────────────────────────
|
||||
let health_row: (i64, i64, i64, i64, i64) = sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) AS total_hosts,
|
||||
COUNT(*) FILTER (WHERE health_status = 'healthy') AS healthy,
|
||||
COUNT(*) FILTER (WHERE health_status = 'degraded') AS degraded,
|
||||
COUNT(*) FILTER (WHERE health_status = 'unreachable') AS unreachable,
|
||||
COUNT(*) FILTER (WHERE health_status = 'pending') AS pending
|
||||
FROM hosts
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query host health aggregates");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (total_hosts, healthy, degraded, unreachable, pending) = health_row;
|
||||
|
||||
// ── 2. Total pending patches across fleet (latest row per host) ───────
|
||||
let total_pending_patches: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COALESCE(SUM(patch_count), 0)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) patch_count
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query total pending patches");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// ── 3. Hosts requiring a reboot (latest patch row per host) ───────────
|
||||
let hosts_requiring_reboot: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) available_patches
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
WHERE available_patches @> '[{"requires_reboot": true}]'
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query reboot-required hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// ── 4. Compliance: hosts with zero pending patches / total hosts ───────
|
||||
// Hosts that have been polled and have patch_count == 0 are considered
|
||||
// compliant. Hosts with no patch data at all are excluded from the
|
||||
// compliance calculation.
|
||||
let compliant_hosts: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) patch_count
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
WHERE patch_count = 0
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query compliant hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let compliance_pct = if total_hosts == 0 {
|
||||
100.0_f64
|
||||
} else {
|
||||
(compliant_hosts as f64 / total_hosts as f64) * 100.0
|
||||
};
|
||||
|
||||
// Round to one decimal place.
|
||||
let compliance_pct = (compliance_pct * 10.0).round() / 10.0;
|
||||
|
||||
Ok(Json(FleetStatus {
|
||||
total_hosts,
|
||||
healthy,
|
||||
degraded,
|
||||
unreachable,
|
||||
pending,
|
||||
total_pending_patches,
|
||||
hosts_requiring_reboot,
|
||||
compliance_pct,
|
||||
}))
|
||||
}
|
||||
571
crates/pm-web/src/routes/users.rs
Executable file
571
crates/pm-web/src/routes/users.rs
Executable file
@ -0,0 +1,571 @@
|
||||
//! User management routes.
|
||||
//!
|
||||
//! GET /api/v1/users — list users (admin only)
|
||||
//! POST /api/v1/users — create user (admin only)
|
||||
//! GET /api/v1/users/:id — get user detail
|
||||
//! PUT /api/v1/users/:id — update user
|
||||
//! DELETE /api/v1/users/:id — delete user (admin only)
|
||||
//! GET /api/v1/users/me — current user profile
|
||||
//! POST /api/v1/users/:id/revoke — revoke all sessions (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{delete, get, post, put},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::validate_password_strength;
|
||||
use pm_auth::{hash_password, rbac::AuthUser, session::force_logout, verify_password};
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{
|
||||
AdminResetPasswordRequest, ChangePasswordRequest, CreateUserRequest, UpdateUserRequest,
|
||||
User,
|
||||
},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_users).post(create_user))
|
||||
.route("/me", get(get_current_user))
|
||||
.route("/me/password", put(change_own_password))
|
||||
.route("/{id}", get(get_user).put(update_user).delete(delete_user))
|
||||
.route("/{id}/password", put(admin_reset_password))
|
||||
.route("/{id}/mfa", delete(admin_disable_mfa))
|
||||
.route("/{id}/revoke", post(revoke_user_sessions))
|
||||
}
|
||||
|
||||
async fn list_users(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Vec<User>>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query_as::<_, User>(
|
||||
r#"SELECT id, username, display_name, email, role, auth_provider,
|
||||
mfa_enabled, is_active, force_password_reset, last_login_at,
|
||||
created_at, updated_at
|
||||
FROM users ORDER BY username"#,
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateUserRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate password strength
|
||||
if let Err(msg) = validate_password_strength(&req.password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
let hash = hash_password(&req.password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let role = match req.role.to_lowercase().as_str() {
|
||||
"admin" => "admin",
|
||||
"reporter" => "reporter",
|
||||
_ => "operator",
|
||||
};
|
||||
|
||||
let id: Uuid = sqlx::query_scalar(
|
||||
r#"INSERT INTO users (username, display_name, email, role, auth_provider, password_hash)
|
||||
VALUES ($1, $2, $3, $4::user_role, 'local', $5)
|
||||
RETURNING id"#,
|
||||
)
|
||||
.bind(&req.username)
|
||||
.bind(req.display_name.as_deref().unwrap_or(&req.username))
|
||||
.bind(&req.email)
|
||||
.bind(role)
|
||||
.bind(&hash)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = if e.to_string().contains("unique") {
|
||||
"Username or email already exists".to_string()
|
||||
} else {
|
||||
"Database error".to_string()
|
||||
};
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
Json(json!({ "error": { "code": "conflict", "message": msg } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "username": req.username }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "id": id, "message": "User created" })))
|
||||
}
|
||||
|
||||
async fn get_current_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
|
||||
fetch_user(&state.db, auth.user_id).await
|
||||
}
|
||||
|
||||
async fn get_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
|
||||
// Users can see themselves; admin can see anyone
|
||||
if !auth.role.is_admin() && auth.user_id != id {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
fetch_user(&state.db, id).await
|
||||
}
|
||||
|
||||
async fn fetch_user(
|
||||
pool: &sqlx::PgPool,
|
||||
id: Uuid,
|
||||
) -> Result<Json<User>, (StatusCode, Json<Value>)> {
|
||||
let user: Option<User> = sqlx::query_as(
|
||||
r#"SELECT id, username, display_name, email, role, auth_provider,
|
||||
mfa_enabled, is_active, force_password_reset, last_login_at,
|
||||
created_at, updated_at
|
||||
FROM users WHERE id = $1"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
user.map(Json).ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<UpdateUserRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() && auth.user_id != id {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Access denied" } })),
|
||||
));
|
||||
}
|
||||
// Only admins can change role or active status
|
||||
if (req.role.is_some() || req.is_active.is_some() || req.force_password_reset.is_some())
|
||||
&& !auth.role.is_admin()
|
||||
{
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(
|
||||
json!({ "error": { "code": "forbidden", "message": "Admin role required to change role, status, or force_password_reset" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let role_str = req
|
||||
.role
|
||||
.as_deref()
|
||||
.map(|r| match r.to_lowercase().as_str() {
|
||||
"admin" => "admin",
|
||||
"reporter" => "reporter",
|
||||
_ => "operator",
|
||||
});
|
||||
|
||||
let rows = sqlx::query(
|
||||
r#"UPDATE users SET
|
||||
display_name = COALESCE($1, display_name),
|
||||
email = COALESCE($2, email),
|
||||
role = COALESCE($3::user_role, role),
|
||||
is_active = COALESCE($4, is_active),
|
||||
force_password_reset = COALESCE($5, force_password_reset),
|
||||
updated_at = NOW()
|
||||
WHERE id = $6"#,
|
||||
)
|
||||
.bind(req.display_name.as_deref())
|
||||
.bind(req.email.as_deref())
|
||||
.bind(role_str)
|
||||
.bind(req.is_active)
|
||||
.bind(req.force_password_reset)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User updated" })))
|
||||
}
|
||||
|
||||
async fn delete_user(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
if auth.user_id == id {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "bad_request", "message": "Cannot delete your own account" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query("DELETE FROM users WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "User deleted" })))
|
||||
}
|
||||
|
||||
async fn revoke_user_sessions(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let count = force_logout(&state.db, id).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(
|
||||
json!({ "message": "Sessions revoked", "count": count }),
|
||||
))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/users/me/password — change own password
|
||||
// ============================================================
|
||||
|
||||
async fn change_own_password(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<ChangePasswordRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Fetch current password hash
|
||||
let hash: Option<String> = sqlx::query_scalar("SELECT password_hash FROM users WHERE id = $1")
|
||||
.bind(auth.user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to fetch password hash");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?
|
||||
.flatten();
|
||||
|
||||
let hash_str = hash.unwrap_or_default();
|
||||
let valid = verify_password(&req.current_password, &hash_str).unwrap_or(false);
|
||||
|
||||
if !valid {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(
|
||||
json!({ "error": { "code": "invalid_password", "message": "Current password is incorrect" } }),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate new password strength
|
||||
if let Err(msg) = validate_password_strength(&req.new_password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
let new_hash = hash_password(&req.new_password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users SET password_hash = $1, force_password_reset = FALSE, updated_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(&new_hash)
|
||||
.bind(auth.user_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to update password");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to update password" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&auth.user_id.to_string()),
|
||||
json!({ "action": "password_change" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Password changed successfully" })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// PUT /api/v1/users/:id/password — admin reset password
|
||||
// ============================================================
|
||||
|
||||
async fn admin_reset_password(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<AdminResetPasswordRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Verify target user exists
|
||||
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !exists {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate new password strength
|
||||
if let Err(msg) = validate_password_strength(&req.new_password) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({ "error": { "code": "weak_password", "message": msg } })),
|
||||
));
|
||||
}
|
||||
|
||||
let new_hash = hash_password(&req.new_password).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": e.to_string() } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE users SET password_hash = $1, force_password_reset = $2, updated_at = NOW() WHERE id = $3",
|
||||
)
|
||||
.bind(&new_hash)
|
||||
.bind(req.force_password_reset)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to reset password");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to reset password" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "action": "admin_password_reset", "force_password_reset": req.force_password_reset }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "Password reset successfully" })))
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// DELETE /api/v1/users/:id/mfa — admin disable MFA
|
||||
// ============================================================
|
||||
|
||||
async fn admin_disable_mfa(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if !auth.role.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": { "code": "forbidden", "message": "Admin role required" } })),
|
||||
));
|
||||
}
|
||||
|
||||
let rows = sqlx::query("UPDATE users SET totp_secret = NULL, mfa_enabled = FALSE, updated_at = NOW() WHERE id = $1")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to disable MFA");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to disable MFA" } })),
|
||||
)
|
||||
})?
|
||||
.rows_affected();
|
||||
|
||||
if rows == 0 {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "User not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::UserUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("user"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "action": "admin_mfa_disabled" }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({ "message": "MFA disabled successfully" })))
|
||||
}
|
||||
205
crates/pm-web/src/routes/ws.rs
Executable file
205
crates/pm-web/src/routes/ws.rs
Executable file
@ -0,0 +1,205 @@
|
||||
//! WebSocket relay routes — M7
|
||||
//!
|
||||
//! POST /api/v1/ws/ticket — create a single-use WS auth ticket (JWT-protected)
|
||||
//! GET /api/v1/ws/jobs — browser WebSocket endpoint (ticket-authenticated)
|
||||
|
||||
use axum::{
|
||||
extract::ws::{Message, WebSocket},
|
||||
extract::{Query, State, WebSocketUpgrade},
|
||||
http::StatusCode,
|
||||
response::{Json, Response},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use chrono::{Duration, Utc};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::postgres::PgListener;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── WsTicket ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Single-use WebSocket authentication ticket stored in-memory.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WsTicket {
|
||||
pub user_id: Uuid,
|
||||
pub role: String,
|
||||
pub expires_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Router for ticket-issuance endpoint (JWT-protected, merged into protected_api).
|
||||
pub fn ticket_router() -> Router<AppState> {
|
||||
Router::new().route("/ws/ticket", post(create_ticket_handler))
|
||||
}
|
||||
|
||||
/// Router for the WebSocket endpoint (ticket-authenticated, NO JWT middleware).
|
||||
pub fn ws_router() -> Router<AppState> {
|
||||
Router::new().route("/api/v1/ws/jobs", get(ws_handler))
|
||||
}
|
||||
|
||||
// ── Error helper ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── POST /api/v1/ws/ticket ────────────────────────────────────────────────────
|
||||
|
||||
/// Issue a single-use WebSocket authentication ticket (60 s expiry).
|
||||
pub async fn create_ticket_handler(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let ticket_id = Ulid::new().to_string();
|
||||
let expires_at = Utc::now() + Duration::seconds(60);
|
||||
|
||||
let ticket = WsTicket {
|
||||
user_id: auth.user_id,
|
||||
role: auth.role.as_str().to_string(),
|
||||
expires_at,
|
||||
};
|
||||
|
||||
state.ws_tickets.insert(ticket_id.clone(), ticket);
|
||||
|
||||
tracing::info!(
|
||||
user_id = %auth.user_id,
|
||||
username = %auth.username,
|
||||
ticket = %ticket_id,
|
||||
"WS ticket issued"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "ticket": ticket_id })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/ws/jobs ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsQuery {
|
||||
pub ticket: String,
|
||||
}
|
||||
|
||||
/// Browser WebSocket upgrade endpoint — authenticates via single-use ticket.
|
||||
pub async fn ws_handler(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<WsQuery>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, (StatusCode, Json<Value>)> {
|
||||
// Validate and consume the ticket atomically.
|
||||
let ticket = {
|
||||
let entry = state.ws_tickets.get(&q.ticket);
|
||||
match entry {
|
||||
None => {
|
||||
return Err(err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_ticket",
|
||||
"WebSocket ticket not found or already used",
|
||||
));
|
||||
},
|
||||
Some(t) => {
|
||||
if t.expires_at < Utc::now() {
|
||||
drop(t);
|
||||
state.ws_tickets.remove(&q.ticket);
|
||||
return Err(err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"ticket_expired",
|
||||
"WebSocket ticket has expired",
|
||||
));
|
||||
}
|
||||
t.clone()
|
||||
},
|
||||
}
|
||||
};
|
||||
// Single-use: remove immediately after validation.
|
||||
state.ws_tickets.remove(&q.ticket);
|
||||
|
||||
tracing::info!(
|
||||
user_id = %ticket.user_id,
|
||||
role = %ticket.role,
|
||||
"Browser WebSocket connection upgraded"
|
||||
);
|
||||
|
||||
let db = state.db.clone();
|
||||
Ok(ws.on_upgrade(move |socket| handle_browser_ws(socket, db, ticket)))
|
||||
}
|
||||
|
||||
// ── WebSocket handler ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Drive the browser WebSocket: LISTEN on `job_update` and forward payloads.
|
||||
async fn handle_browser_ws(mut socket: WebSocket, db: sqlx::PgPool, ticket: WsTicket) {
|
||||
// Acquire a dedicated PG listener connection.
|
||||
let mut listener = match PgListener::connect_with(&db).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "Failed to create PgListener");
|
||||
let _ = socket
|
||||
.send(Message::Text(
|
||||
json!({ "error": "internal_error" }).to_string().into(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = listener.listen("job_update").await {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener LISTEN failed");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS: LISTEN job_update started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward PG notifications to the browser.
|
||||
notify_result = listener.recv() => {
|
||||
match notify_result {
|
||||
Ok(notification) => {
|
||||
let payload = notification.payload().to_string();
|
||||
tracing::debug!(user_id = %ticket.user_id, payload = %payload, "Forwarding job_update");
|
||||
if socket.send(Message::Text(payload.into())).await.is_err() {
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS send failed — client disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener recv error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle incoming frames from the browser (ping/close).
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS closed by client");
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) if socket.send(Message::Pong(data.clone())).await.is_err() => {
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
tracing::debug!(error = %e, user_id = %ticket.user_id, "Browser WS recv error");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS handler exiting");
|
||||
}
|
||||
31
crates/pm-worker/Cargo.toml
Normal file
31
crates/pm-worker/Cargo.toml
Normal file
@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "pm-worker"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "pm-worker"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
pm-agent-client = { path = "../pm-agent-client" }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { version = "0.26" }
|
||||
rustls-pemfile = { version = "2" }
|
||||
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
|
||||
reqwest = { workspace = true }
|
||||
45
crates/pm-worker/src/agent_loader.rs
Executable file
45
crates/pm-worker/src/agent_loader.rs
Executable file
@ -0,0 +1,45 @@
|
||||
//! Helper for loading mTLS certificate/key material from disk.
|
||||
//!
|
||||
//! Reads PEM files referenced in [`SecurityConfig`] and returns the raw bytes
|
||||
//! needed by [`pm_agent_client::AgentClient`].
|
||||
|
||||
use pm_core::config::SecurityConfig;
|
||||
|
||||
/// Raw PEM bytes for mTLS client authentication and CA verification.
|
||||
pub struct AgentCerts {
|
||||
pub client_cert: Vec<u8>,
|
||||
pub client_key: Vec<u8>,
|
||||
pub ca_cert: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Load agent mTLS certificates from the paths specified in [`SecurityConfig`].
|
||||
///
|
||||
/// Returns an error if any file cannot be read. The caller should handle
|
||||
/// the error gracefully (log and skip the poll cycle) rather than crashing.
|
||||
pub fn load_agent_certs(security: &SecurityConfig) -> anyhow::Result<AgentCerts> {
|
||||
let client_cert = std::fs::read(&security.agent_client_cert_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read agent client cert '{}': {}",
|
||||
security.agent_client_cert_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let client_key = std::fs::read(&security.agent_client_key_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read agent client key '{}': {}",
|
||||
security.agent_client_key_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let ca_cert = std::fs::read(&security.ca_cert_path).map_err(|e| {
|
||||
anyhow::anyhow!("Failed to read CA cert '{}': {}", security.ca_cert_path, e)
|
||||
})?;
|
||||
|
||||
Ok(AgentCerts {
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
})
|
||||
}
|
||||
86
crates/pm-worker/src/audit_verifier.rs
Executable file
86
crates/pm-worker/src/audit_verifier.rs
Executable file
@ -0,0 +1,86 @@
|
||||
//! Periodic audit log integrity verification.
|
||||
//!
|
||||
//! Runs every 24 hours, walks the audit_log rows ordered by id,
|
||||
//! verifies each row_hash matches the recomputed hash, and logs the
|
||||
//! result as an `AuditIntegrityVerified` event. If tampering is
|
||||
//! detected, logs an error and creates an alert.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use sqlx::PgPool;
|
||||
|
||||
use pm_core::audit::{log_event, verify_integrity, AuditAction};
|
||||
use pm_core::config::AppConfig;
|
||||
|
||||
/// Run the audit integrity verifier every 24 hours.
|
||||
pub async fn run_audit_verifier(pool: PgPool, _config: Arc<AppConfig>) {
|
||||
tracing::info!("Audit integrity verifier started");
|
||||
|
||||
// Run immediately on startup
|
||||
verify_once(&pool).await;
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(24 * 60 * 60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
tracing::info!("Running scheduled audit integrity verification");
|
||||
verify_once(&pool).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a single integrity verification pass.
|
||||
async fn verify_once(pool: &PgPool) {
|
||||
let result = verify_integrity(pool).await;
|
||||
|
||||
if result.intact {
|
||||
tracing::info!(
|
||||
rows_checked = result.rows_checked,
|
||||
"Audit integrity verification passed"
|
||||
);
|
||||
} else {
|
||||
tracing::error!(
|
||||
rows_checked = result.rows_checked,
|
||||
error_count = result.errors.len(),
|
||||
"Audit integrity verification FAILED — tampering detected!"
|
||||
);
|
||||
|
||||
for err in &result.errors {
|
||||
tracing::error!(
|
||||
row_id = err.row_id,
|
||||
expected_hash = %err.expected_hash,
|
||||
actual_hash = %err.actual_hash,
|
||||
"Audit chain integrity error"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Log the verification event
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::AuditIntegrityVerified,
|
||||
None,
|
||||
None,
|
||||
Some("audit_log"),
|
||||
None,
|
||||
serde_json::json!({
|
||||
"intact": result.intact,
|
||||
"rows_checked": result.rows_checked,
|
||||
"error_count": result.errors.len(),
|
||||
"errors": result.errors.iter().take(10).map(|e| serde_json::json!({
|
||||
"row_id": e.row_id,
|
||||
"expected_hash": e.expected_hash,
|
||||
"actual_hash": e.actual_hash,
|
||||
})).collect::<Vec<_>>(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Update last verified timestamp
|
||||
let _ = sqlx::query(
|
||||
"UPDATE system_config SET value = NOW()::text, updated_at = NOW() WHERE key = 'audit_integrity_last_verified'",
|
||||
)
|
||||
.execute(pool)
|
||||
.await;
|
||||
}
|
||||
331
crates/pm-worker/src/email.rs
Executable file
331
crates/pm-worker/src/email.rs
Executable file
@ -0,0 +1,331 @@
|
||||
//! Email notification module.
|
||||
//!
|
||||
//! Loads SMTP configuration from `system_config` and sends notification emails
|
||||
//! for patch job events (completion, failure) and maintenance window reminders.
|
||||
//! All emails are optional and disabled by default via `notification_email_enabled`.
|
||||
|
||||
use lettre::{
|
||||
message::{header::ContentType, Mailbox},
|
||||
transport::smtp::authentication::Credentials,
|
||||
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use pm_core::audit::{log_event, AuditAction};
|
||||
|
||||
/// SMTP configuration loaded from `system_config`.
|
||||
struct SmtpSettings {
|
||||
enabled: bool,
|
||||
host: String,
|
||||
port: u16,
|
||||
username: String,
|
||||
password: String,
|
||||
from: String,
|
||||
tls_mode: String,
|
||||
}
|
||||
|
||||
/// Notification preferences loaded from `system_config`.
|
||||
struct NotificationSettings {
|
||||
email_enabled: bool,
|
||||
email_from: String,
|
||||
recipients: Vec<String>,
|
||||
}
|
||||
|
||||
/// Load SMTP settings from the `system_config` table.
|
||||
async fn load_smtp_settings(pool: &PgPool) -> SmtpSettings {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(
|
||||
"SELECT key, value FROM system_config WHERE key IN (
|
||||
'smtp_enabled', 'smtp_host', 'smtp_port', 'smtp_username',
|
||||
'smtp_password', 'smtp_from', 'smtp_tls_mode'
|
||||
)",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let get = |key: &str| -> String {
|
||||
rows.iter()
|
||||
.find(|(k, _)| k == key)
|
||||
.map(|(_, v)| v.clone())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
SmtpSettings {
|
||||
enabled: get("smtp_enabled") == "true",
|
||||
host: get("smtp_host"),
|
||||
port: get("smtp_port").parse().unwrap_or(587),
|
||||
username: get("smtp_username"),
|
||||
password: get("smtp_password"),
|
||||
from: get("smtp_from"),
|
||||
tls_mode: get("smtp_tls_mode"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load notification preferences from `system_config`.
|
||||
async fn load_notification_settings(pool: &PgPool) -> NotificationSettings {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(
|
||||
"SELECT key, value FROM system_config WHERE key IN (
|
||||
'notification_email_enabled', 'notification_email_from', 'notification_email_recipients'
|
||||
)",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let get = |key: &str| -> String {
|
||||
rows.iter()
|
||||
.find(|(k, _)| k == key)
|
||||
.map(|(_, v)| v.clone())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
let recipients: Vec<String> =
|
||||
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
|
||||
|
||||
NotificationSettings {
|
||||
email_enabled: get("notification_email_enabled") == "true",
|
||||
email_from: get("notification_email_from"),
|
||||
recipients,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an async SMTP transport from settings.
|
||||
fn build_transport(settings: &SmtpSettings) -> Result<AsyncSmtpTransport<Tokio1Executor>, String> {
|
||||
match settings.tls_mode.as_str() {
|
||||
"tls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(&settings.host)
|
||||
.map_err(|e| format!("TLS relay error: {}", e))?;
|
||||
builder = builder.port(settings.port);
|
||||
if !settings.username.is_empty() {
|
||||
builder = builder.credentials(Credentials::new(
|
||||
settings.username.clone(),
|
||||
settings.password.clone(),
|
||||
));
|
||||
}
|
||||
Ok(builder.build())
|
||||
},
|
||||
"starttls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(&settings.host)
|
||||
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
|
||||
builder = builder.port(settings.port);
|
||||
if !settings.username.is_empty() {
|
||||
builder = builder.credentials(Credentials::new(
|
||||
settings.username.clone(),
|
||||
settings.password.clone(),
|
||||
));
|
||||
}
|
||||
Ok(builder.build())
|
||||
},
|
||||
_ => {
|
||||
// "none" — plaintext / no TLS
|
||||
let mut builder =
|
||||
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(&settings.host)
|
||||
.port(settings.port);
|
||||
if !settings.username.is_empty() {
|
||||
builder = builder.credentials(Credentials::new(
|
||||
settings.username.clone(),
|
||||
settings.password.clone(),
|
||||
));
|
||||
}
|
||||
Ok(builder.build())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Send an email notification. Returns true if the email was sent successfully.
|
||||
async fn send_email(pool: &PgPool, subject: &str, body: &str) -> bool {
|
||||
let smtp = match load_smtp_settings(pool).await {
|
||||
s if !s.enabled => {
|
||||
tracing::debug!("SMTP not enabled, skipping email notification");
|
||||
return false;
|
||||
},
|
||||
s => s,
|
||||
};
|
||||
|
||||
let notif = load_notification_settings(pool).await;
|
||||
if !notif.email_enabled {
|
||||
tracing::debug!("Email notifications disabled, skipping");
|
||||
return false;
|
||||
}
|
||||
|
||||
if notif.recipients.is_empty() {
|
||||
tracing::debug!("No email recipients configured, skipping notification");
|
||||
return false;
|
||||
}
|
||||
|
||||
let from_addr = if notif.email_from.is_empty() {
|
||||
smtp.from.clone()
|
||||
} else {
|
||||
notif.email_from
|
||||
};
|
||||
|
||||
let from_mailbox: Mailbox = match from_addr.parse() {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Invalid from address for email notification");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let mut builder = Message::builder()
|
||||
.from(from_mailbox.clone())
|
||||
.subject(subject)
|
||||
.header(ContentType::TEXT_PLAIN);
|
||||
|
||||
// Add all recipients
|
||||
for recipient in ¬if.recipients {
|
||||
let mailbox: Mailbox = match recipient.parse() {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, recipient = %recipient, "Invalid recipient address");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
builder = builder.to(mailbox);
|
||||
}
|
||||
|
||||
let email = match builder.body(body.to_string()) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to build email message");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let transport = match build_transport(&smtp) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to build SMTP transport");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
match transport.send(email).await {
|
||||
Ok(_) => {
|
||||
tracing::info!(subject, "Email notification sent successfully");
|
||||
true
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, subject, "Failed to send email notification");
|
||||
false
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a patch failure notification email for a specific host.
|
||||
pub async fn send_patch_failure_email(
|
||||
pool: &PgPool,
|
||||
host_fqdn: &str,
|
||||
job_id: &str,
|
||||
error_message: &str,
|
||||
) {
|
||||
let subject = format!("[Patch Manager] Patch Failed on {}", host_fqdn);
|
||||
let body = format!(
|
||||
"Patch operation failed on host: {host_fqdn}\n\
|
||||
Job ID: {job_id}\n\
|
||||
Error: {error_message}\n\
|
||||
\n\
|
||||
Please review the job details in the Patch Manager dashboard."
|
||||
);
|
||||
|
||||
let sent = send_email(pool, &subject, &body).await;
|
||||
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::EmailNotificationSent,
|
||||
None,
|
||||
None,
|
||||
Some("patch_job"),
|
||||
Some(job_id),
|
||||
serde_json::json!({
|
||||
"type": "patch_failure",
|
||||
"host_fqdn": host_fqdn,
|
||||
"sent": sent,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Send a job completion notification email.
|
||||
pub async fn send_job_completion_email(
|
||||
pool: &PgPool,
|
||||
job_id: &str,
|
||||
host_count: i64,
|
||||
succeeded_count: i64,
|
||||
failed_count: i64,
|
||||
) {
|
||||
let subject = format!("[Patch Manager] Job {} Completed", job_id);
|
||||
let body = format!(
|
||||
"Patch job completed: {job_id}\n\
|
||||
Total hosts: {host_count}\n\
|
||||
Succeeded: {succeeded_count}\n\
|
||||
Failed: {failed_count}\n\
|
||||
\n\
|
||||
Please review the job details in the Patch Manager dashboard."
|
||||
);
|
||||
|
||||
let sent = send_email(pool, &subject, &body).await;
|
||||
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::EmailNotificationSent,
|
||||
None,
|
||||
None,
|
||||
Some("patch_job"),
|
||||
Some(job_id),
|
||||
serde_json::json!({
|
||||
"type": "job_completion",
|
||||
"host_count": host_count,
|
||||
"succeeded_count": succeeded_count,
|
||||
"failed_count": failed_count,
|
||||
"sent": sent,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Send a maintenance window reminder email.
|
||||
#[allow(dead_code)]
|
||||
pub async fn send_maintenance_window_reminder_email(
|
||||
pool: &PgPool,
|
||||
host_fqdn: &str,
|
||||
window_label: &str,
|
||||
start_at: &str,
|
||||
) {
|
||||
let subject = format!(
|
||||
"[Patch Manager] Upcoming Maintenance Window: {}",
|
||||
window_label
|
||||
);
|
||||
let body = format!(
|
||||
"Maintenance window reminder:\n\
|
||||
Host: {host_fqdn}\n\
|
||||
Window: {window_label}\n\
|
||||
Starts at: {start_at}\n\
|
||||
\n\
|
||||
Patch operations will begin at the scheduled time."
|
||||
);
|
||||
|
||||
let sent = send_email(pool, &subject, &body).await;
|
||||
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::MaintenanceWindowReminder,
|
||||
None,
|
||||
None,
|
||||
Some("maintenance_window"),
|
||||
None,
|
||||
serde_json::json!({
|
||||
"type": "maintenance_reminder",
|
||||
"host_fqdn": host_fqdn,
|
||||
"window_label": window_label,
|
||||
"sent": sent,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
480
crates/pm-worker/src/health_check_poller.rs
Executable file
480
crates/pm-worker/src/health_check_poller.rs
Executable file
@ -0,0 +1,480 @@
|
||||
//! Periodic health check poller for configured service and HTTP checks.
|
||||
//!
|
||||
//! Polls every `health_check_poll_interval_secs`, querying each enabled health
|
||||
//! check definition and storing results in `host_health_check_results`.
|
||||
//! Results older than 4 days are pruned on each cycle.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use pm_core::{config::AppConfig, crypto};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// DB row types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Row fetched for each enabled health check, joined with host connection info.
|
||||
#[derive(FromRow)]
|
||||
#[allow(dead_code)]
|
||||
struct HealthCheckRow {
|
||||
id: Uuid,
|
||||
host_id: Uuid,
|
||||
name: String,
|
||||
check_type: String,
|
||||
service_name: Option<String>,
|
||||
url: Option<String>,
|
||||
expected_body: Option<String>,
|
||||
ignore_cert_errors: Option<bool>,
|
||||
basic_auth_user: Option<String>,
|
||||
basic_auth_pass_encrypted: Option<Vec<u8>>,
|
||||
basic_auth_pass_nonce: Option<Vec<u8>>,
|
||||
target_host_id: Option<Uuid>,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public entry point
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run the health check poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all enabled health checks are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_health_check_results` and stale rows are pruned.
|
||||
pub async fn run_health_check_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.health_check_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(interval_secs, "Health check poller started");
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Load certs on each cycle so cert rotation is picked up automatically.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Health check poller: failed to load agent certs — skipping cycle"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
// Load the crypto key for decrypting HTTP check passwords.
|
||||
let crypto_key = match crypto::load_or_create_key(Path::new(crypto::KEY_PATH)) {
|
||||
Ok(k) => Arc::new(k),
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Health check poller: failed to load crypto key — skipping cycle"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Fetch all enabled health checks with host connection info.
|
||||
let checks: Vec<HealthCheckRow> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
hc.id,
|
||||
hc.host_id,
|
||||
hc.name,
|
||||
hc.check_type,
|
||||
hc.service_name,
|
||||
hc.url,
|
||||
hc.expected_body,
|
||||
hc.ignore_cert_errors,
|
||||
hc.basic_auth_user,
|
||||
hc.basic_auth_pass_encrypted,
|
||||
hc.basic_auth_pass_nonce,
|
||||
hc.target_host_id,
|
||||
host(COALESCE(th.ip_address, h.ip_address))::text AS ip_address,
|
||||
COALESCE(th.agent_port, h.agent_port) AS agent_port
|
||||
FROM host_health_checks hc
|
||||
JOIN hosts h ON h.id = hc.host_id
|
||||
LEFT JOIN hosts th ON th.id = hc.target_host_id
|
||||
WHERE hc.enabled = TRUE
|
||||
ORDER BY hc.id
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health check poller: failed to fetch health checks");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if checks.is_empty() {
|
||||
tracing::debug!("Health check poller: no enabled health checks, skipping cycle");
|
||||
prune_old_results(&pool).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = checks.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for check in checks {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
let ckey = crypto_key.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
run_check(pool, check, &cert, &key, &ca, &ckey).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Collect results and tally counts.
|
||||
let mut healthy_count = 0usize;
|
||||
let mut unhealthy_count = 0usize;
|
||||
let mut error_count = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(true) => healthy_count += 1,
|
||||
Ok(false) => unhealthy_count += 1,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health check poller task panicked");
|
||||
error_count += 1;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
total,
|
||||
healthy_count,
|
||||
unhealthy_count,
|
||||
error_count,
|
||||
"Health check poll cycle complete"
|
||||
);
|
||||
|
||||
// Prune results older than 4 days.
|
||||
prune_old_results(&pool).await;
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Check dispatch
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run a single health check and persist the result. Returns `true` if healthy.
|
||||
async fn run_check(
|
||||
pool: PgPool,
|
||||
check: HealthCheckRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
crypto_key: &[u8; 32],
|
||||
) -> bool {
|
||||
let start = Instant::now();
|
||||
|
||||
let (healthy, detail) = match check.check_type.as_str() {
|
||||
"service" => run_service_check(&check, client_cert, client_key, ca_cert).await,
|
||||
"http" => run_http_check(&check, crypto_key).await,
|
||||
other => {
|
||||
tracing::warn!(
|
||||
check_id = %check.id,
|
||||
check_type = other,
|
||||
"Unknown health check type — treating as unhealthy"
|
||||
);
|
||||
(false, format!("Unknown check type: {other}"))
|
||||
},
|
||||
};
|
||||
|
||||
let latency_ms = start.elapsed().as_millis() as i32;
|
||||
|
||||
// Persist the result.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_check_results (check_id, healthy, detail, latency_ms)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
)
|
||||
.bind(check.id)
|
||||
.bind(healthy)
|
||||
.bind(&detail)
|
||||
.bind(latency_ms)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
check_id = %check.id,
|
||||
error = %e,
|
||||
"Health check poller: failed to insert result"
|
||||
);
|
||||
}
|
||||
|
||||
healthy
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Service check (via mTLS AgentClient)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute a service check by calling the agent's `/api/v1/system/services/{name}` endpoint.
|
||||
async fn run_service_check(
|
||||
check: &HealthCheckRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> (bool, String) {
|
||||
let service_name = match &check.service_name {
|
||||
Some(name) => name.clone(),
|
||||
None => {
|
||||
return (false, "Service check missing service_name".to_string());
|
||||
},
|
||||
};
|
||||
|
||||
let client = match AgentClient::new(
|
||||
&check.ip_address,
|
||||
check.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return (false, format!("Failed to build AgentClient: {e}"));
|
||||
},
|
||||
};
|
||||
|
||||
match client.service_status(&service_name).await {
|
||||
Ok(data) => {
|
||||
let detail = if data.healthy {
|
||||
format!(
|
||||
"Service '{}' is {}/{} (enabled: {})",
|
||||
data.name, data.active_state, data.sub_state, data.enabled_state
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"Service '{}' status: {}/{} (unhealthy, enabled: {})",
|
||||
data.name, data.active_state, data.sub_state, data.enabled_state
|
||||
)
|
||||
};
|
||||
(data.healthy, detail)
|
||||
},
|
||||
Err(AgentClientError::Timeout) => (
|
||||
false,
|
||||
format!("Agent timed out querying service '{service_name}'"),
|
||||
),
|
||||
Err(AgentClientError::Connect(_)) => (
|
||||
false,
|
||||
format!("Agent connection refused for service '{service_name}'"),
|
||||
),
|
||||
Err(AgentClientError::ApiError { code, message }) => {
|
||||
// 404, 400, 500 etc. from the agent means the service is unhealthy.
|
||||
(false, format!("Agent error [{code}]: {message}"))
|
||||
},
|
||||
Err(e) => (
|
||||
false,
|
||||
format!("Agent error querying service '{service_name}': {e}"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// HTTP check (via reqwest, no mTLS)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute an HTTP check by making a GET request to the configured URL.
|
||||
/// Supports optional basic auth (decrypted from DB) and substring body matching.
|
||||
async fn run_http_check(check: &HealthCheckRow, crypto_key: &[u8; 32]) -> (bool, String) {
|
||||
let url = match &check.url {
|
||||
Some(u) => u.clone(),
|
||||
None => {
|
||||
return (false, "HTTP check missing URL".to_string());
|
||||
},
|
||||
};
|
||||
|
||||
// Build a reqwest client for this check.
|
||||
// Use danger_accept_invalid_certs if ignore_cert_errors is set (default true).
|
||||
let ignore_cert_errors = check.ignore_cert_errors.unwrap_or(true);
|
||||
|
||||
let client_builder = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.redirect(reqwest::redirect::Policy::limited(5));
|
||||
|
||||
let client = if ignore_cert_errors {
|
||||
client_builder
|
||||
.danger_accept_invalid_certs(true)
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
} else {
|
||||
client_builder
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
};
|
||||
|
||||
// Build the request.
|
||||
let mut request = client.get(&url);
|
||||
|
||||
// Add basic auth if configured.
|
||||
if let Some(user) = &check.basic_auth_user {
|
||||
// Decrypt the password if present.
|
||||
let password = match (
|
||||
&check.basic_auth_pass_encrypted,
|
||||
&check.basic_auth_pass_nonce,
|
||||
) {
|
||||
(Some(enc), Some(nonce)) => match crypto::decrypt(enc, nonce, crypto_key) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return (false, format!("Failed to decrypt basic auth password: {e}"));
|
||||
},
|
||||
},
|
||||
_ => {
|
||||
// No encrypted password stored — treat as missing credentials.
|
||||
return (
|
||||
false,
|
||||
"HTTP check has basic_auth_user but no encrypted password".to_string(),
|
||||
);
|
||||
},
|
||||
};
|
||||
request = request.basic_auth(user.as_str(), Some(password.as_str()));
|
||||
}
|
||||
|
||||
// Execute the request.
|
||||
let response = match request.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
if e.is_timeout() {
|
||||
return (false, format!("HTTP check timed out: {url}"));
|
||||
} else if e.is_connect() {
|
||||
return (false, format!("HTTP check connection failed: {url}"));
|
||||
} else {
|
||||
return (false, format!("HTTP check request error: {e}"));
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
|
||||
// Check HTTP status code.
|
||||
if !status.is_success() {
|
||||
return (
|
||||
false,
|
||||
format!("HTTP check returned status {} for {url}", status.as_u16()),
|
||||
);
|
||||
}
|
||||
|
||||
// Read the response body for substring matching.
|
||||
let body = match response.text().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
return (
|
||||
false,
|
||||
format!("HTTP check failed to read response body: {e}"),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
// Check expected_body substring match.
|
||||
if let Some(expected) = &check.expected_body {
|
||||
if !body.contains(expected) {
|
||||
return (
|
||||
false,
|
||||
format!("HTTP check body mismatch for {url}: expected substring not found"),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
true,
|
||||
format!("HTTP check OK for {url} (status {})", status.as_u16()),
|
||||
)
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Prune old results
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Delete health check results older than 4 days.
|
||||
async fn prune_old_results(pool: &PgPool) {
|
||||
match sqlx::query(
|
||||
"DELETE FROM host_health_check_results WHERE checked_at < NOW() - INTERVAL '4 days'",
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!(
|
||||
rows_deleted = result.rows_affected(),
|
||||
"Health check poller: pruned old results"
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health check poller: failed to prune old results");
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Health check gate for job executor
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Check whether all enabled health checks for a host are healthy.
|
||||
///
|
||||
/// Returns `Ok(true)` if all checks pass (or no checks are configured),
|
||||
/// `Ok(false)` if any check is unhealthy or has no result yet.
|
||||
pub async fn check_host_health_checks(pool: &PgPool, host_id: Uuid) -> anyhow::Result<bool> {
|
||||
// Check if there are any enabled health checks for this host.
|
||||
let check_count: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM host_health_checks WHERE host_id = $1 AND enabled = TRUE",
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if check_count.0 == 0 {
|
||||
// No health checks configured for this host — treat as healthy.
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Find any enabled check that has no healthy result or an unhealthy latest result.
|
||||
let unhealthy_count: (i64,) = sqlx::query_as(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM host_health_checks hc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT healthy
|
||||
FROM host_health_check_results r
|
||||
WHERE r.check_id = hc.id
|
||||
ORDER BY r.checked_at DESC
|
||||
LIMIT 1
|
||||
) latest ON true
|
||||
WHERE hc.host_id = $1
|
||||
AND hc.enabled = TRUE
|
||||
AND (latest.healthy IS NULL OR latest.healthy = FALSE)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(unhealthy_count.0 == 0)
|
||||
}
|
||||
249
crates/pm-worker/src/health_poller.rs
Normal file
249
crates/pm-worker/src/health_poller.rs
Normal file
@ -0,0 +1,249 @@
|
||||
//! Periodic health poller for all registered hosts.
|
||||
//!
|
||||
//! Polls every host via the agent `/health` endpoint on each tick of
|
||||
//! `health_poll_interval_secs`, with bounded concurrency controlled by a
|
||||
//! [`tokio::sync::Semaphore`]. Also calls `/system/info` to refresh
|
||||
//! `os_family`, `os_name`, `arch`, and `agent_version` in the hosts table.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
use pm_core::{config::AppConfig, models::HostHealthStatus};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host projection fetched for each poll cycle.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the health poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all registered hosts are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_health_data` and the `hosts` table is updated.
|
||||
pub async fn run_health_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.health_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(interval_secs, "Health poller started");
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Load certs on each cycle so cert rotation is picked up automatically.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health poller: failed to load agent certs — skipping cycle");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
// Fetch all hosts.
|
||||
let hosts: Vec<HostRow> = match sqlx::query_as(
|
||||
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health poller: failed to fetch hosts");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!("Health poller: no hosts registered, skipping cycle");
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = hosts.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for host in hosts {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
poll_host_health(pool, host, &cert, &key, &ca).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Collect results and tally counts.
|
||||
let mut healthy = 0usize;
|
||||
let mut degraded = 0usize;
|
||||
let mut unreachable = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(HostHealthStatus::Healthy) => healthy += 1,
|
||||
Ok(HostHealthStatus::Degraded) => degraded += 1,
|
||||
Ok(HostHealthStatus::Unreachable) => unreachable += 1,
|
||||
Ok(_) => {},
|
||||
Err(e) => tracing::error!(error = %e, "Health poller task panicked"),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
total,
|
||||
healthy,
|
||||
degraded,
|
||||
unreachable,
|
||||
"Health poll cycle complete"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll a single host, persist the result, and return the determined status.
|
||||
///
|
||||
/// Also updates `agent_version` from the health response and
|
||||
/// `os_family`/`os_name`/`arch` from the `/system/info` endpoint when available.
|
||||
async fn poll_host_health(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> HostHealthStatus {
|
||||
// Determine status, payload, agent version, and optional system info.
|
||||
let (status, payload, agent_version, sys_info) = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Health poller: failed to build AgentClient"
|
||||
);
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
},
|
||||
Ok(client) => {
|
||||
let (status, payload, version) = match client.health().await {
|
||||
Ok(data) => {
|
||||
let payload = serde_json::to_value(&data).unwrap_or_default();
|
||||
(HostHealthStatus::Healthy, payload, Some(data.version))
|
||||
},
|
||||
Err(AgentClientError::Timeout) => {
|
||||
tracing::warn!(host_id = %host.id, "Health poller: agent timed out");
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
)
|
||||
},
|
||||
Err(AgentClientError::Connect(_)) => {
|
||||
tracing::warn!(host_id = %host.id, "Health poller: agent connection refused");
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
)
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Health poller: agent error");
|
||||
(
|
||||
HostHealthStatus::Degraded,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
)
|
||||
},
|
||||
};
|
||||
|
||||
// Try to fetch system info for OS/arch details (best-effort).
|
||||
let sys_info = if status != HostHealthStatus::Unreachable {
|
||||
match client.system_info().await {
|
||||
Ok(info) => Some(info),
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Health poller: failed to get system info (non-fatal)"
|
||||
);
|
||||
None
|
||||
},
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(status, payload, version, sys_info)
|
||||
},
|
||||
};
|
||||
|
||||
// Insert into host_health_data.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_data (host_id, status, payload)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&status)
|
||||
.bind(&payload)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to insert health data");
|
||||
}
|
||||
|
||||
// Build OS name from system info components (e.g. "Ubuntu 24.04").
|
||||
let os_name_from_sysinfo = sys_info
|
||||
.as_ref()
|
||||
.map(|i| format!("{} {}", i.os, i.os_version));
|
||||
|
||||
// Update hosts table with health status, agent version, and OS details.
|
||||
// COALESCE preserves existing values when new data is unavailable.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE hosts
|
||||
SET health_status = $2, last_health_at = NOW(),
|
||||
agent_version = COALESCE($3, agent_version),
|
||||
os_family = COALESCE($4, os_family),
|
||||
os_name = COALESCE($5, os_name),
|
||||
arch = COALESCE($6, arch)
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&status)
|
||||
.bind(&agent_version)
|
||||
.bind(sys_info.as_ref().map(|i| i.os.as_str()))
|
||||
.bind(os_name_from_sysinfo)
|
||||
.bind(sys_info.as_ref().map(|i| i.architecture.as_str()))
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to update host status");
|
||||
}
|
||||
|
||||
status
|
||||
}
|
||||
1057
crates/pm-worker/src/job_executor.rs
Executable file
1057
crates/pm-worker/src/job_executor.rs
Executable file
File diff suppressed because it is too large
Load Diff
208
crates/pm-worker/src/main.rs
Executable file
208
crates/pm-worker/src/main.rs
Executable file
@ -0,0 +1,208 @@
|
||||
//! pm-worker — Linux Patch Manager background worker.
|
||||
//!
|
||||
//! Handles scheduled polling, job execution, maintenance window scheduling,
|
||||
//! retry logic, email notifications, audit integrity verification, and data pruning.
|
||||
|
||||
mod agent_loader;
|
||||
mod audit_verifier;
|
||||
mod email;
|
||||
mod health_check_poller;
|
||||
mod health_poller;
|
||||
mod job_executor;
|
||||
mod maintenance_scheduler;
|
||||
mod patch_poller;
|
||||
mod refresh_listener;
|
||||
mod ws_relay;
|
||||
|
||||
use chrono::Utc;
|
||||
use pm_core::{config::AppConfig, db, logging};
|
||||
use sqlx::PgPool;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time;
|
||||
|
||||
use audit_verifier::run_audit_verifier;
|
||||
use health_check_poller::run_health_check_poller;
|
||||
use health_poller::run_health_poller;
|
||||
use job_executor::run_job_executor;
|
||||
use maintenance_scheduler::run_maintenance_scheduler;
|
||||
use patch_poller::run_patch_poller;
|
||||
use refresh_listener::run_refresh_listener;
|
||||
use ws_relay::run_ws_relay;
|
||||
|
||||
/// Minimum number of applied migrations the worker requires before
|
||||
/// accepting work. Prevents the worker from running against a schema
|
||||
/// that hasn't been migrated yet.
|
||||
const REQUIRED_MIGRATION_COUNT: i64 = 16;
|
||||
|
||||
/// How long to wait between schema-version checks before giving up.
|
||||
const SCHEMA_CHECK_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Install the default crypto provider for rustls (required since 0.23)
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install rustls crypto provider");
|
||||
|
||||
// Load configuration
|
||||
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
|
||||
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
|
||||
|
||||
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
|
||||
eprintln!("Config file not found or invalid, using defaults");
|
||||
AppConfig::default()
|
||||
});
|
||||
|
||||
// Initialize logging
|
||||
logging::init(&config.logging);
|
||||
|
||||
tracing::info!(
|
||||
version = env!("CARGO_PKG_VERSION"),
|
||||
"patch-manager-worker starting"
|
||||
);
|
||||
|
||||
// Initialize database pool
|
||||
let pool = db::init_pool(&config.database).await?;
|
||||
|
||||
// Wait for schema to be at the expected version (web process runs migrations)
|
||||
wait_for_schema(&pool).await?;
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Spawn worker tasks
|
||||
let heartbeat_handle = tokio::spawn(run_heartbeat(
|
||||
pool.clone(),
|
||||
config.worker.heartbeat_interval_secs,
|
||||
));
|
||||
|
||||
// M4: agent health poller, patch data poller, on-demand refresh listener
|
||||
let health_handle = tokio::spawn(run_health_poller(pool.clone(), config.clone()));
|
||||
let patch_handle = tokio::spawn(run_patch_poller(pool.clone(), config.clone()));
|
||||
let refresh_handle = tokio::spawn(run_refresh_listener(pool.clone(), config.clone()));
|
||||
|
||||
// M5: job execution engine
|
||||
let job_exec_handle = tokio::spawn(run_job_executor(pool.clone(), config.clone()));
|
||||
|
||||
// M6: maintenance window scheduler
|
||||
let maint_sched_handle = tokio::spawn(run_maintenance_scheduler(pool.clone(), config.clone()));
|
||||
|
||||
// M7: WS relay — streams agent job events → DB → pg_notify → browser WS
|
||||
let ws_relay_handle = tokio::spawn(run_ws_relay(pool.clone(), config.clone()));
|
||||
|
||||
// M11: audit integrity verification (runs every 24 hours)
|
||||
let audit_verifier_handle = tokio::spawn(run_audit_verifier(pool.clone(), config.clone()));
|
||||
|
||||
// Health check poller — runs configured service/HTTP health checks
|
||||
let health_check_handle = tokio::spawn(run_health_check_poller(pool.clone(), config.clone()));
|
||||
|
||||
// Enrollment cleanup task (runs every hour)
|
||||
let enrollment_cleanup_handle = tokio::spawn(run_enrollment_cleanup_task(pool.clone()));
|
||||
|
||||
tracing::info!("Worker tasks started");
|
||||
|
||||
// Wait for all tasks (they run indefinitely)
|
||||
let _ = tokio::join!(
|
||||
heartbeat_handle,
|
||||
health_handle,
|
||||
patch_handle,
|
||||
refresh_handle,
|
||||
job_exec_handle,
|
||||
maint_sched_handle,
|
||||
ws_relay_handle,
|
||||
audit_verifier_handle,
|
||||
health_check_handle,
|
||||
enrollment_cleanup_handle,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Wait until the database schema has at least `REQUIRED_MIGRATION_COUNT`
|
||||
/// successful migrations applied. Retries every 5 seconds up to
|
||||
/// `SCHEMA_CHECK_TIMEOUT`.
|
||||
async fn wait_for_schema(pool: &PgPool) -> anyhow::Result<()> {
|
||||
let deadline = tokio::time::Instant::now() + SCHEMA_CHECK_TIMEOUT;
|
||||
|
||||
loop {
|
||||
match db::check_schema_version(pool).await {
|
||||
Ok(count) if count >= REQUIRED_MIGRATION_COUNT => {
|
||||
tracing::info!(migration_count = count, "Schema version check passed");
|
||||
return Ok(());
|
||||
},
|
||||
Ok(count) => {
|
||||
tracing::warn!(
|
||||
migration_count = count,
|
||||
required = REQUIRED_MIGRATION_COUNT,
|
||||
"Schema not ready, waiting..."
|
||||
);
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Schema version check failed, retrying...");
|
||||
},
|
||||
}
|
||||
|
||||
if tokio::time::Instant::now() >= deadline {
|
||||
anyhow::bail!(
|
||||
"Schema not ready after {}s — is the web process running migrations?",
|
||||
SCHEMA_CHECK_TIMEOUT.as_secs()
|
||||
);
|
||||
}
|
||||
|
||||
time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes a heartbeat row to `worker_heartbeat` every `interval_secs`.
|
||||
/// The web process can query this to confirm the worker is alive.
|
||||
async fn run_heartbeat(pool: PgPool, interval_secs: u64) {
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
let mut ticker = time::interval(interval);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO worker_heartbeat (id, last_seen, worker_version)
|
||||
VALUES (1, NOW(), $1)
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET last_seen = EXCLUDED.last_seen,
|
||||
worker_version = EXCLUDED.worker_version
|
||||
"#,
|
||||
)
|
||||
.bind(env!("CARGO_PKG_VERSION"))
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => tracing::debug!("Worker heartbeat written"),
|
||||
Err(e) => tracing::error!(error = %e, "Worker heartbeat failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Periodically deletes expired enrollment requests.
|
||||
async fn run_enrollment_cleanup_task(pool: PgPool) {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(3600)); // Every hour
|
||||
interval.tick().await; // Initial tick to run immediately if needed
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = Utc::now();
|
||||
match sqlx::query("DELETE FROM enrollment_requests WHERE expires_at < $1")
|
||||
.bind(now)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!(
|
||||
removed = result.rows_affected(),
|
||||
"Purged expired enrollment requests"
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => tracing::error!(error = %e, "Failed to purge expired enrollment requests"),
|
||||
}
|
||||
}
|
||||
}
|
||||
381
crates/pm-worker/src/maintenance_scheduler.rs
Executable file
381
crates/pm-worker/src/maintenance_scheduler.rs
Executable file
@ -0,0 +1,381 @@
|
||||
//! Maintenance window scheduler.
|
||||
//!
|
||||
//! Polls every 60 seconds and performs two tasks:
|
||||
//!
|
||||
//! 1. **Auto-apply**: For each enabled maintenance window with `auto_apply = true`
|
||||
//! that is currently open, if the host has pending patches and no existing
|
||||
//! patch_apply job queued/running for that window, automatically creates one.
|
||||
//!
|
||||
//! 2. **Dispatch**: For each open window, dispatch any queued non-immediate
|
||||
//! patch jobs associated with the window's host.
|
||||
//!
|
||||
//! A window is considered "open" when:
|
||||
//! - `once` — `start_at <= NOW() < start_at + duration_minutes * '1 minute'`
|
||||
//! - `daily` — current UTC time-of-day is within the window's daily slot
|
||||
//! - `weekly` — same as daily, but only on the matching `recurrence_day` (0=Sun)
|
||||
//! - `monthly` — same as daily, but only on the matching `recurrence_day` (1-31)
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::time;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::job_executor::process_job;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Internal types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct OpenWindowHost {
|
||||
host_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct QueuedJobId {
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct AutoApplyWindow {
|
||||
window_id: Uuid,
|
||||
host_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
#[allow(dead_code)]
|
||||
struct PendingPatchHost {
|
||||
host_id: Uuid,
|
||||
patch_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct InsertedJobId {
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public entry point
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run the maintenance scheduler indefinitely.
|
||||
/// Spawned by `pm-worker/src/main.rs` alongside the job executor.
|
||||
pub async fn run_maintenance_scheduler(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Maintenance scheduler started");
|
||||
|
||||
// First tick fires immediately; consume it to align with job_executor.
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(60));
|
||||
ticker.tick().await;
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
tracing::debug!("Maintenance scheduler: checking open windows");
|
||||
|
||||
// Step 1: Auto-create patch_apply jobs for windows with auto_apply=true
|
||||
auto_create_patch_jobs(pool.clone(), config.clone()).await;
|
||||
|
||||
// Step 2: Dispatch any queued non-immediate jobs for open windows
|
||||
dispatch_open_window_jobs(pool.clone(), config.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Step 1: Auto-create patch_apply jobs
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// For each enabled maintenance window that is currently open AND has
|
||||
/// `auto_apply = true`, check if the host has pending patches and no
|
||||
/// existing patch_apply job for this window cycle. If so, create one.
|
||||
async fn auto_create_patch_jobs(pool: PgPool, _config: Arc<AppConfig>) {
|
||||
// Find all open windows with auto_apply=true
|
||||
let auto_windows: Vec<AutoApplyWindow> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT mw.id AS window_id, mw.host_id
|
||||
FROM maintenance_windows mw
|
||||
WHERE mw.enabled = TRUE
|
||||
AND mw.auto_apply = TRUE
|
||||
AND (
|
||||
( mw.recurrence = 'once'
|
||||
AND mw.start_at <= NOW()
|
||||
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
|
||||
)
|
||||
OR
|
||||
( mw.recurrence = 'daily'
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
( mw.recurrence = 'weekly'
|
||||
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
( mw.recurrence = 'monthly'
|
||||
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(w) => w,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "auto_create_patch_jobs: open-windows query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if auto_windows.is_empty() {
|
||||
tracing::debug!("auto_create: no open auto-apply windows this cycle");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
auto_window_count = auto_windows.len(),
|
||||
"auto_create: found open auto-apply windows"
|
||||
);
|
||||
|
||||
for win in &auto_windows {
|
||||
// Check if host has pending patches
|
||||
let pending: Option<PendingPatchHost> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT host_id, patch_count
|
||||
FROM host_patch_data
|
||||
WHERE host_id = $1 AND patch_count > 0
|
||||
"#,
|
||||
)
|
||||
.bind(win.host_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %win.host_id,
|
||||
"auto_create: patch data query failed"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let Some(pending) = pending else {
|
||||
tracing::debug!(
|
||||
host_id = %win.host_id,
|
||||
"auto_create: no pending patches, skipping"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
// Check if there's already a queued/running patch_apply job for this host
|
||||
// that was created during this window cycle (within the window's time range).
|
||||
// We use a simpler check: any non-completed patch_apply job for this host
|
||||
// that references this maintenance window, OR any non-immediate job without
|
||||
// a window that was created since the window opened.
|
||||
let existing_job: bool = match sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM patch_jobs pj
|
||||
JOIN patch_job_hosts pjh ON pj.id = pjh.job_id
|
||||
WHERE pjh.host_id = $1
|
||||
AND pj.status IN ('queued', 'running', 'pending')
|
||||
AND pj.kind = 'patch_apply'
|
||||
AND (
|
||||
pj.maintenance_window_id = $2
|
||||
OR
|
||||
(pj.immediate = FALSE AND pj.created_at >=
|
||||
(SELECT start_at - INTERVAL '5 minutes' FROM maintenance_windows WHERE id = $2)
|
||||
)
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(win.host_id)
|
||||
.bind(win.window_id)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %win.host_id,
|
||||
"auto_create: existing job check failed"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if existing_job {
|
||||
tracing::debug!(
|
||||
host_id = %win.host_id,
|
||||
window_id = %win.window_id,
|
||||
"auto_create: existing job already queued/running, skipping"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create a new patch_apply job for this host, linked to the window.
|
||||
let job: Option<InsertedJobId> = match sqlx::query_as(
|
||||
r#"
|
||||
WITH new_job AS (
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, maintenance_window_id, immediate, patch_selection, notes)
|
||||
VALUES
|
||||
('patch_apply', 'queued', $1, FALSE, '[]'::jsonb,
|
||||
'Auto-created by maintenance window scheduler')
|
||||
RETURNING id AS job_id
|
||||
)
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
SELECT new_job.job_id, $2, 'queued'
|
||||
FROM new_job
|
||||
RETURNING job_id
|
||||
"#,
|
||||
)
|
||||
.bind(win.window_id)
|
||||
.bind(win.host_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %win.host_id,
|
||||
window_id = %win.window_id,
|
||||
"auto_create: job insert failed"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(job) = job {
|
||||
tracing::info!(
|
||||
job_id = %job.job_id,
|
||||
host_id = %win.host_id,
|
||||
window_id = %win.window_id,
|
||||
patch_count = pending.patch_count,
|
||||
"auto_create: created patch_apply job for host in maintenance window"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Step 2: Dispatch queued non-immediate jobs
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Find all hosts with a currently-open maintenance window, then for each,
|
||||
/// find their queued non-immediate job entries and dispatch them.
|
||||
async fn dispatch_open_window_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
// ── 1. Find all host_ids with an open window right now ─────────────────
|
||||
let open_hosts: Vec<OpenWindowHost> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT mw.host_id
|
||||
FROM maintenance_windows mw
|
||||
WHERE mw.enabled = TRUE
|
||||
AND (
|
||||
-- One-time: absolute window
|
||||
( mw.recurrence = 'once'
|
||||
AND mw.start_at <= NOW()
|
||||
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
|
||||
)
|
||||
OR
|
||||
-- Daily: time-of-day slot, any day
|
||||
( mw.recurrence = 'daily'
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
-- Weekly: matching day-of-week + time-of-day slot
|
||||
( mw.recurrence = 'weekly'
|
||||
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
-- Monthly: matching day-of-month + time-of-day slot
|
||||
( mw.recurrence = 'monthly'
|
||||
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(hosts) => hosts,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "dispatch_open_window_jobs: open-hosts query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if open_hosts.is_empty() {
|
||||
tracing::debug!("Maintenance scheduler: no open windows this cycle");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
open_host_count = open_hosts.len(),
|
||||
"Maintenance scheduler: found hosts with open windows"
|
||||
);
|
||||
|
||||
// ── 2. For each open host, find distinct queued non-immediate job IDs ──
|
||||
for host in open_hosts {
|
||||
let job_ids: Vec<QueuedJobId> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT pjh.job_id
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs j ON j.id = pjh.job_id
|
||||
WHERE pjh.host_id = $1
|
||||
AND pjh.status = 'queued'
|
||||
AND j.immediate = FALSE
|
||||
AND j.status != 'cancelled'
|
||||
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
|
||||
"#,
|
||||
)
|
||||
.bind(host.host_id)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %host.host_id,
|
||||
"dispatch_open_window_jobs: queued jobs query failed"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
for job in job_ids {
|
||||
tracing::info!(
|
||||
job_id = %job.job_id,
|
||||
host_id = %host.host_id,
|
||||
"Maintenance scheduler: dispatching non-immediate job (window open)"
|
||||
);
|
||||
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
let job_id = job.job_id;
|
||||
tokio::spawn(async move {
|
||||
process_job(p, c, job_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
202
crates/pm-worker/src/patch_poller.rs
Executable file
202
crates/pm-worker/src/patch_poller.rs
Executable file
@ -0,0 +1,202 @@
|
||||
//! Periodic patch-data poller for all registered hosts.
|
||||
//!
|
||||
//! Polls every host via the agent `/patches` and `/packages` endpoints on
|
||||
//! each tick of `patch_poll_interval_secs`, with bounded concurrency
|
||||
//! controlled by a [`tokio::sync::Semaphore`].
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::AgentClient;
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host projection fetched for each poll cycle.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the patch poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all registered hosts are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_patch_data` and `hosts.last_patch_at` is updated.
|
||||
pub async fn run_patch_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.patch_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(interval_secs, "Patch poller started");
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller: failed to load agent certs — skipping cycle");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
let hosts: Vec<HostRow> = match sqlx::query_as(
|
||||
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller: failed to fetch hosts");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!("Patch poller: no hosts registered, skipping cycle");
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = hosts.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for host in hosts {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
poll_host_patches(pool, host, &cert, &key, &ca).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
let mut succeeded = 0usize;
|
||||
let mut failed = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(true) => succeeded += 1,
|
||||
Ok(false) => failed += 1,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller task panicked");
|
||||
failed += 1;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(total, succeeded, failed, "Patch poll cycle complete");
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll a single host for patch and package data, persist the result.
|
||||
/// Returns `true` on success, `false` on any error.
|
||||
async fn poll_host_patches(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> bool {
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: failed to build AgentClient");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
// Fetch patches and packages concurrently.
|
||||
let (patches_result, packages_result) =
|
||||
tokio::join!(client.patches(), client.packages_upgradable());
|
||||
|
||||
let patches_data = match patches_result {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: patches() failed");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let packages_data = match packages_result {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: packages_upgradable() failed");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
|
||||
let installed_packages = serde_json::to_value(&packages_data.packages).unwrap_or_default();
|
||||
let patch_count = patches_data.total as i32;
|
||||
let cve_count = patches_data
|
||||
.patches
|
||||
.iter()
|
||||
.filter(|p| !p.cve_ids.is_empty())
|
||||
.count() as i32;
|
||||
|
||||
// Upsert into host_patch_data (one row per host, latest poll wins).
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_patch_data
|
||||
(host_id, available_patches, installed_packages, patch_count, cve_count)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (host_id) DO UPDATE SET
|
||||
available_patches = EXCLUDED.available_patches,
|
||||
installed_packages = EXCLUDED.installed_packages,
|
||||
patch_count = EXCLUDED.patch_count,
|
||||
cve_count = EXCLUDED.cve_count,
|
||||
polled_at = NOW()
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&available_patches)
|
||||
.bind(&installed_packages)
|
||||
.bind(patch_count)
|
||||
.bind(cve_count)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to insert patch data");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update hosts.last_patch_at.
|
||||
if let Err(e) = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
|
||||
.bind(host.id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to update last_patch_at");
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
host_id = %host.id,
|
||||
patch_count,
|
||||
cve_count,
|
||||
"Patch data collected"
|
||||
);
|
||||
|
||||
true
|
||||
}
|
||||
269
crates/pm-worker/src/refresh_listener.rs
Executable file
269
crates/pm-worker/src/refresh_listener.rs
Executable file
@ -0,0 +1,269 @@
|
||||
//! On-demand refresh listener.
|
||||
//!
|
||||
//! Listens on the PostgreSQL `refresh_requested` NOTIFY channel. When a
|
||||
//! notification arrives the payload is expected to be a host UUID string.
|
||||
//! The listener immediately polls that host for health and patch data and
|
||||
//! persists the results — bypassing the normal poll intervals.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
use pm_core::{config::AppConfig, models::HostHealthStatus};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::time;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host row used for on-demand refresh.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the LISTEN/NOTIFY refresh listener indefinitely.
|
||||
///
|
||||
/// Automatically reconnects if the underlying PostgreSQL connection drops.
|
||||
pub async fn run_refresh_listener(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Refresh listener started — listening on 'refresh_requested'");
|
||||
|
||||
loop {
|
||||
if let Err(e) = listen_loop(&pool, &config).await {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Refresh listener disconnected, reconnecting in 5s"
|
||||
);
|
||||
time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inner loop — returns `Err` only on a fatal listener error so the outer
|
||||
/// loop can reconnect.
|
||||
async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> {
|
||||
let mut listener = sqlx::postgres::PgListener::connect(&config.database.url).await?;
|
||||
|
||||
listener.listen("refresh_requested").await?;
|
||||
|
||||
tracing::debug!("Refresh listener connected and listening");
|
||||
|
||||
loop {
|
||||
let notification = listener.recv().await?;
|
||||
let payload = notification.payload().to_string();
|
||||
|
||||
tracing::info!(payload, "Refresh notification received");
|
||||
|
||||
let host_id = match payload.parse::<Uuid>() {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
payload,
|
||||
error = %e,
|
||||
"Refresh listener: invalid UUID in notification payload"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Fetch the host from the database.
|
||||
let host: Option<HostRow> = sqlx::query_as(
|
||||
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts WHERE id = $1",
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let host = match host {
|
||||
Some(h) => h,
|
||||
None => {
|
||||
tracing::warn!(%host_id, "Refresh listener: host not found");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Load certs for this refresh.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
%host_id,
|
||||
error = %e,
|
||||
"Refresh listener: failed to load agent certs"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Spawn the actual work so the listener loop is not blocked.
|
||||
let pool_clone = pool.clone();
|
||||
let cert = certs.client_cert;
|
||||
let key = certs.client_key;
|
||||
let ca = certs.ca_cert;
|
||||
|
||||
tokio::spawn(async move {
|
||||
refresh_host(pool_clone, host, &cert, &key, &ca).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a full health + patch refresh for one host and persist results.
|
||||
async fn refresh_host(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) {
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to build AgentClient"
|
||||
);
|
||||
persist_health_unreachable(&pool, host.id).await;
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
// ── Health ────────────────────────────────────────────────────────────
|
||||
let (health_status, health_payload) = match client.health().await {
|
||||
Ok(data) => {
|
||||
let payload = serde_json::to_value(&data).unwrap_or_default();
|
||||
(HostHealthStatus::Healthy, payload)
|
||||
},
|
||||
Err(AgentClientError::Timeout) | Err(AgentClientError::Connect(_)) => {
|
||||
tracing::warn!(host_id = %host.id, "Refresh: agent unreachable");
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
)
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Refresh: health error");
|
||||
(
|
||||
HostHealthStatus::Degraded,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
)
|
||||
},
|
||||
};
|
||||
|
||||
persist_health(&pool, host.id, &health_status, &health_payload).await;
|
||||
|
||||
// ── Patch data ────────────────────────────────────────────────────────
|
||||
let (patches_result, packages_result) =
|
||||
tokio::join!(client.patches(), client.packages_upgradable());
|
||||
|
||||
match (patches_result, packages_result) {
|
||||
(Ok(patches_data), Ok(packages_data)) => {
|
||||
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
|
||||
let installed_packages =
|
||||
serde_json::to_value(&packages_data.packages).unwrap_or_default();
|
||||
let patch_count = patches_data.total as i32;
|
||||
let cve_count = patches_data
|
||||
.patches
|
||||
.iter()
|
||||
.filter(|p| !p.cve_ids.is_empty())
|
||||
.count() as i32;
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_patch_data
|
||||
(host_id, available_patches, installed_packages, patch_count, cve_count)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (host_id) DO UPDATE SET
|
||||
available_patches = EXCLUDED.available_patches,
|
||||
installed_packages = EXCLUDED.installed_packages,
|
||||
patch_count = EXCLUDED.patch_count,
|
||||
cve_count = EXCLUDED.cve_count,
|
||||
polled_at = NOW()
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&available_patches)
|
||||
.bind(&installed_packages)
|
||||
.bind(patch_count)
|
||||
.bind(cve_count)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to insert patch data"
|
||||
);
|
||||
} else {
|
||||
let _ = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
|
||||
.bind(host.id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
host_id = %host.id,
|
||||
patch_count,
|
||||
cve_count,
|
||||
"On-demand refresh complete"
|
||||
);
|
||||
}
|
||||
},
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to collect patch data"
|
||||
);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn persist_health_unreachable(pool: &PgPool, host_id: Uuid) {
|
||||
let status = HostHealthStatus::Unreachable;
|
||||
let payload = serde_json::Value::Object(Default::default());
|
||||
persist_health(pool, host_id, &status, &payload).await;
|
||||
}
|
||||
|
||||
async fn persist_health(
|
||||
pool: &PgPool,
|
||||
host_id: Uuid,
|
||||
status: &HostHealthStatus,
|
||||
payload: &serde_json::Value,
|
||||
) {
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_data (host_id, status, payload)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(status)
|
||||
.bind(payload)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
%host_id,
|
||||
error = %e,
|
||||
"Refresh: failed to insert health data"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
sqlx::query("UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.bind(status)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%host_id, error = %e, "Refresh: failed to update host health_status");
|
||||
}
|
||||
}
|
||||
718
crates/pm-worker/src/ws_relay.rs
Executable file
718
crates/pm-worker/src/ws_relay.rs
Executable file
@ -0,0 +1,718 @@
|
||||
//! WS relay — M7
|
||||
//!
|
||||
//! For every running `patch_job_hosts` row that has an `agent_job_id`, open a
|
||||
//! WebSocket to the corresponding agent, stream job-status events, update the
|
||||
//! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS
|
||||
//! handler can forward the event to connected clients.
|
||||
|
||||
use std::{collections::HashSet, error::Error, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::StreamExt;
|
||||
use rustls::{
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
ClientConfig as TlsClientConfig, RootCertStore,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::protocol::Message, Connector};
|
||||
use uuid::Uuid;
|
||||
|
||||
use pm_agent_client::client::AgentClient;
|
||||
use pm_agent_client::client::DEFAULT_AGENT_PORT;
|
||||
use pm_core::config::AppConfig;
|
||||
|
||||
// ── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct RunningHostJob {
|
||||
job_id: Uuid,
|
||||
host_id: Uuid,
|
||||
agent_job_id: String,
|
||||
host_address: String,
|
||||
}
|
||||
|
||||
/// JSON event streamed by the agent over its WS endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AgentWsEvent {
|
||||
#[allow(dead_code)]
|
||||
job_id: String,
|
||||
status: String,
|
||||
output: Option<String>,
|
||||
error: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
progress_percent: Option<u8>,
|
||||
}
|
||||
|
||||
/// Payload broadcast via `pg_notify('job_update', …)`.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NotifyPayload {
|
||||
event_type: String, // "host" or "job"
|
||||
job_id: String,
|
||||
host_id: String,
|
||||
status: String,
|
||||
output: Option<String>,
|
||||
error_message: Option<String>,
|
||||
agent_job_id: String,
|
||||
// Job-level fields (only present when event_type === "job")
|
||||
succeeded_count: Option<i64>,
|
||||
failed_count: Option<i64>,
|
||||
host_count: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Cert PEM bytes for building AgentClient ────────────────────────────────────
|
||||
|
||||
/// Raw PEM bytes read from the security config cert paths.
|
||||
/// Used to build an [`AgentClient`] for HTTP polling fallback.
|
||||
#[derive(Clone)]
|
||||
struct CertPems {
|
||||
client_cert: Vec<u8>,
|
||||
client_key: Vec<u8>,
|
||||
ca_cert: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Read the three PEM files referenced by the security config.
|
||||
/// Mirrors the file reads in [`build_tls_config`] but returns raw bytes instead of
|
||||
/// parsing them into rustls types.
|
||||
async fn read_cert_pems(config: &AppConfig) -> anyhow::Result<CertPems> {
|
||||
let sec = &config.security;
|
||||
let client_cert = tokio::fs::read(&sec.agent_client_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
|
||||
let client_key = tokio::fs::read(&sec.agent_client_key_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
|
||||
let ca_cert = tokio::fs::read(&sec.ca_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
|
||||
Ok(CertPems {
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Long-running task: polls the DB for running host-jobs and spawns a per-pair
|
||||
/// relay task for each one that isn't already being tracked.
|
||||
pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("WS relay task started");
|
||||
|
||||
let active: Arc<Mutex<HashSet<(Uuid, Uuid)>>> = Arc::new(Mutex::new(HashSet::new()));
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(10));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let rows = match query_running_jobs(&pool).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: DB poll failed");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
for row in rows {
|
||||
let key = (row.job_id, row.host_id);
|
||||
|
||||
// Skip pairs that already have an active relay.
|
||||
if active.lock().await.contains(&key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build the rustls ClientConfig once per connection.
|
||||
let tls_config = match build_tls_config(&config).await {
|
||||
Ok(c) => Arc::new(c),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: TLS config error");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Read raw cert PEM bytes for HTTP polling fallback.
|
||||
let cert_pems = match read_cert_pems(&config).await {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: cert PEM read error");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let poll_interval = config.worker.ws_relay_poll_interval_secs;
|
||||
|
||||
active.lock().await.insert(key);
|
||||
|
||||
let pool_c = pool.clone();
|
||||
let active_c = active.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
agent_job_id = %row.agent_job_id,
|
||||
host = %row.host_address,
|
||||
"WS relay: starting relay"
|
||||
);
|
||||
|
||||
match relay_one_job(&pool_c, &row, tls_config, &cert_pems, poll_interval).await {
|
||||
Ok(()) => tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: completed"
|
||||
),
|
||||
Err(e) => tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: ended with error"
|
||||
),
|
||||
}
|
||||
|
||||
active_c.lock().await.remove(&key);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── DB helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn query_running_jobs(pool: &PgPool) -> anyhow::Result<Vec<RunningHostJob>> {
|
||||
sqlx::query_as::<_, RunningHostJob>(
|
||||
r#"
|
||||
SELECT
|
||||
pjh.job_id,
|
||||
pjh.host_id,
|
||||
pjh.agent_job_id,
|
||||
COALESCE(h.fqdn, host(h.ip_address)::text) AS host_address
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.status = 'running'::job_status
|
||||
AND pjh.agent_job_id IS NOT NULL
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("query_running_jobs")
|
||||
}
|
||||
|
||||
// ── TLS ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn build_tls_config(config: &AppConfig) -> anyhow::Result<TlsClientConfig> {
|
||||
let sec = &config.security;
|
||||
|
||||
let cert_pem = tokio::fs::read(&sec.agent_client_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
|
||||
let key_pem = tokio::fs::read(&sec.agent_client_key_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
|
||||
let ca_pem = tokio::fs::read(&sec.ca_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
|
||||
|
||||
// Parse client certificate chain.
|
||||
let client_certs: Vec<CertificateDer<'static>> = {
|
||||
let mut cur = std::io::Cursor::new(&cert_pem);
|
||||
rustls_pemfile::certs(&mut cur)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.context("parse client cert PEM")?
|
||||
};
|
||||
|
||||
// Parse client private key.
|
||||
let client_key: PrivateKeyDer<'static> = {
|
||||
let mut cur = std::io::Cursor::new(&key_pem);
|
||||
rustls_pemfile::private_key(&mut cur)
|
||||
.context("parse client key PEM")?
|
||||
.context("no private key in PEM")?
|
||||
};
|
||||
|
||||
// Build root store from CA cert.
|
||||
let mut root_store = RootCertStore::empty();
|
||||
{
|
||||
let mut cur = std::io::Cursor::new(&ca_pem);
|
||||
for cert_result in rustls_pemfile::certs(&mut cur) {
|
||||
root_store
|
||||
.add(cert_result.context("read CA cert entry")?)
|
||||
.context("add CA cert to root store")?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut config = TlsClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_client_auth_cert(client_certs, client_key)
|
||||
.context("build TlsClientConfig")?;
|
||||
|
||||
// WebSocket requires HTTP/1.1 — without ALPN the server may negotiate
|
||||
// h2 (HTTP/2), which breaks the WebSocket upgrade handshake
|
||||
// ("Key mismatch in Sec-WebSocket-Accept header").
|
||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
// ── Per-job relay ─────────────────────────────────────────────────────────────
|
||||
|
||||
async fn relay_one_job(
|
||||
pool: &PgPool,
|
||||
row: &RunningHostJob,
|
||||
tls_config: Arc<TlsClientConfig>,
|
||||
cert_pems: &CertPems,
|
||||
poll_interval_secs: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!(
|
||||
"wss://{}:{}/api/v1/ws/jobs",
|
||||
row.host_address, DEFAULT_AGENT_PORT,
|
||||
);
|
||||
|
||||
let (ws_stream, _) = match connect_async_tls_with_config(
|
||||
url.as_str(),
|
||||
None,
|
||||
false,
|
||||
Some(Connector::Rustls(tls_config)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(ws) => ws,
|
||||
Err(e) => {
|
||||
// Log the full error chain for TLS debugging
|
||||
let mut source = e.source();
|
||||
let mut depth = 0;
|
||||
while let Some(err) = source {
|
||||
tracing::warn!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
depth,
|
||||
error = %err,
|
||||
"WS relay: TLS connection error detail"
|
||||
);
|
||||
source = err.source();
|
||||
depth += 1;
|
||||
}
|
||||
// Fall back to HTTP polling instead of returning an error.
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
host = %row.host_address,
|
||||
"WS relay: WebSocket connection failed, falling back to HTTP polling"
|
||||
);
|
||||
return relay_one_job_poll(pool, row, cert_pems, poll_interval_secs).await;
|
||||
},
|
||||
};
|
||||
|
||||
let (_sink, mut stream) = ws_stream.split();
|
||||
|
||||
while let Some(frame) = stream.next().await {
|
||||
let frame = match frame {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: stream error"
|
||||
);
|
||||
break;
|
||||
},
|
||||
};
|
||||
|
||||
let text = match frame {
|
||||
Message::Text(t) => t.to_string(),
|
||||
Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(),
|
||||
Message::Close(_) => {
|
||||
tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly");
|
||||
break;
|
||||
},
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event: AgentWsEvent = match serde_json::from_str(&text) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e, raw = %text,
|
||||
"WS relay: unparseable agent frame"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
process_event(pool, row, &event).await;
|
||||
|
||||
if matches!(event.status.as_str(), "succeeded" | "failed" | "cancelled") {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %event.status,
|
||||
"WS relay: terminal state — stopping"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Per-job HTTP polling fallback ─────────────────────────────────────────────
|
||||
|
||||
/// Fall back to HTTP polling when the WebSocket connection fails.
|
||||
///
|
||||
/// Builds an [`AgentClient`] from the same cert paths used for TLS, polls
|
||||
/// `GET /api/v1/jobs/{id}` every `poll_interval_secs`, and calls
|
||||
/// [`process_event`] whenever the status changes from the previous poll.
|
||||
/// Stops when the job reaches a terminal state (succeeded/failed/cancelled).
|
||||
async fn relay_one_job_poll(
|
||||
pool: &PgPool,
|
||||
row: &RunningHostJob,
|
||||
cert_pems: &CertPems,
|
||||
poll_interval_secs: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
// Build the HTTP client using the same mTLS certs.
|
||||
let agent_client = AgentClient::new(
|
||||
&row.host_address,
|
||||
DEFAULT_AGENT_PORT,
|
||||
&cert_pems.client_cert,
|
||||
&cert_pems.client_key,
|
||||
&cert_pems.ca_cert,
|
||||
)
|
||||
.context("build AgentClient for polling fallback")?;
|
||||
|
||||
let mut last_status: Option<String> = None;
|
||||
let poll_interval = Duration::from_secs(poll_interval_secs);
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let job_status = match agent_client.job_status(&row.agent_job_id).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
error = %e,
|
||||
"WS relay poll: job_status request failed, will retry"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Map agent status to the WS event format.
|
||||
// The agent uses "completed" but pm-worker expects "succeeded".
|
||||
// The agent uses "queued" but pm-worker expects "running".
|
||||
let mapped_status = match job_status.status.as_str() {
|
||||
"queued" => "running",
|
||||
"running" => "running",
|
||||
"succeeded" => "succeeded",
|
||||
"completed" => "succeeded",
|
||||
"failed" => "failed",
|
||||
"cancelled" => "cancelled",
|
||||
other => {
|
||||
tracing::warn!(
|
||||
status = %other,
|
||||
job_id = %row.job_id,
|
||||
"WS relay poll: unknown agent status, treating as running"
|
||||
);
|
||||
"running"
|
||||
},
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
agent_status = %job_status.status,
|
||||
mapped_status = %mapped_status,
|
||||
progress = ?job_status.progress_percent,
|
||||
"WS relay poll: fetched job status"
|
||||
);
|
||||
|
||||
// Only process when the status has changed since the last poll.
|
||||
if last_status.as_deref() == Some(mapped_status) {
|
||||
continue;
|
||||
}
|
||||
|
||||
last_status = Some(mapped_status.to_string());
|
||||
|
||||
let event = AgentWsEvent {
|
||||
job_id: job_status.job_id.clone(),
|
||||
status: mapped_status.to_string(),
|
||||
output: job_status.output.clone(),
|
||||
error: job_status.error.clone(),
|
||||
progress_percent: job_status.progress_percent,
|
||||
};
|
||||
|
||||
process_event(pool, row, &event).await;
|
||||
|
||||
// Stop polling on terminal states.
|
||||
if matches!(mapped_status, "succeeded" | "failed" | "cancelled") {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %mapped_status,
|
||||
"WS relay poll: terminal state — stopping"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Event processing ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) {
|
||||
// Map agent status string to DB job_status enum value.
|
||||
let db_status = match event.status.as_str() {
|
||||
"running" => "running",
|
||||
"succeeded" => "succeeded",
|
||||
"failed" => "failed",
|
||||
"cancelled" => "cancelled",
|
||||
other => {
|
||||
tracing::warn!(status = %other, "WS relay: unknown agent status");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
let output = event.output.as_deref().unwrap_or("");
|
||||
let error_msg = event.error.as_deref();
|
||||
|
||||
// Determine timestamps based on terminal state.
|
||||
let is_terminal = matches!(db_status, "succeeded" | "failed" | "cancelled");
|
||||
|
||||
// Update the DB row.
|
||||
let update_result = if is_terminal {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = $1::job_status,
|
||||
output = CASE WHEN $2 != '' THEN $2 ELSE output END,
|
||||
error_message = $3,
|
||||
completed_at = NOW()
|
||||
WHERE job_id = $4
|
||||
AND host_id = $5
|
||||
"#,
|
||||
)
|
||||
.bind(db_status)
|
||||
.bind(output)
|
||||
.bind(error_msg)
|
||||
.bind(row.job_id)
|
||||
.bind(row.host_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = $1::job_status,
|
||||
output = CASE WHEN $2 != '' THEN $2 ELSE output END
|
||||
WHERE job_id = $3
|
||||
AND host_id = $4
|
||||
"#,
|
||||
)
|
||||
.bind(db_status)
|
||||
.bind(output)
|
||||
.bind(row.job_id)
|
||||
.bind(row.host_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
};
|
||||
|
||||
if let Err(e) = update_result {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: DB update failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Also update the parent patch_jobs status when the host-level job reaches
|
||||
// a terminal state: running → if all hosts terminal then update parent.
|
||||
if is_terminal {
|
||||
update_parent_job_status(pool, row.job_id).await;
|
||||
}
|
||||
|
||||
// Fire pg_notify so browser WS handlers forward the host-level event.
|
||||
let payload = NotifyPayload {
|
||||
event_type: "host".to_string(),
|
||||
job_id: row.job_id.to_string(),
|
||||
host_id: row.host_id.to_string(),
|
||||
status: db_status.to_string(),
|
||||
output: event.output.clone(),
|
||||
error_message: event.error.clone(),
|
||||
agent_job_id: row.agent_job_id.clone(),
|
||||
succeeded_count: None,
|
||||
failed_count: None,
|
||||
host_count: None,
|
||||
};
|
||||
|
||||
let payload_json = match serde_json::to_string(&payload) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "WS relay: failed to serialize notify payload");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_json)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: pg_notify failed"
|
||||
);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %db_status,
|
||||
"WS relay: pg_notify sent"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Parent job status rollup ──────────────────────────────────────────────────
|
||||
|
||||
/// After a host-level job reaches a terminal state, check whether ALL hosts for
|
||||
/// that job are now terminal and update the parent `patch_jobs` row accordingly.
|
||||
///
|
||||
/// If the parent job transitions to a terminal status, also fires a `job_update`
|
||||
/// pg_notify with `event_type: "job"` so the frontend can update the job row.
|
||||
async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) {
|
||||
// Count hosts that are still in a non-terminal state.
|
||||
let pending: i64 = match sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
AND status NOT IN (
|
||||
'succeeded'::job_status,
|
||||
'failed'::job_status,
|
||||
'cancelled'::job_status
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if pending > 0 {
|
||||
return; // still hosts running — parent stays running
|
||||
}
|
||||
|
||||
// All hosts terminal — determine final parent status and counts.
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct RollupCounts {
|
||||
total: i64,
|
||||
succeeded: i64,
|
||||
failed: i64,
|
||||
}
|
||||
|
||||
let counts: RollupCounts = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE status = 'succeeded') AS succeeded,
|
||||
COUNT(*) FILTER (WHERE status = 'failed') AS failed
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: rollup query failed");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
let final_status = if counts.failed > 0 {
|
||||
"failed"
|
||||
} else {
|
||||
"succeeded"
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(final_status)
|
||||
.bind(job_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"update_parent_job_status: UPDATE failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"Parent job status updated"
|
||||
);
|
||||
|
||||
// Fire job-level pg_notify so the frontend can update the job row.
|
||||
let payload = NotifyPayload {
|
||||
event_type: "job".to_string(),
|
||||
job_id: job_id.to_string(),
|
||||
host_id: String::new(), // no specific host for job-level events
|
||||
status: final_status.to_string(),
|
||||
output: None,
|
||||
error_message: None,
|
||||
agent_job_id: String::new(),
|
||||
succeeded_count: Some(counts.succeeded),
|
||||
failed_count: Some(counts.failed),
|
||||
host_count: Some(counts.total),
|
||||
};
|
||||
|
||||
let payload_json = match serde_json::to_string(&payload) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: failed to serialize job-level notify payload");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_json)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
%job_id,
|
||||
"update_parent_job_status: job-level pg_notify failed"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"Job-level pg_notify sent"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user