Private
Public Access
1
0

feat: add bump-version.sh script for version management

Automates version bumps across all version source files:
- Cargo.toml (PRIMARY - workspace.package.version)
- debian/changelog (prepend new entry)
- debian/control (update Version field)
- scripts/build-package.sh (update VERSION variable)
- frontend/package.json (update version field)
- Stale references check after bump

Usage: ./scripts/bump-version.sh <new_version> <old_version>
This commit is contained in:
2026-05-28 10:52:16 -05:00
commit 124b5b0e3b
153 changed files with 41878 additions and 0 deletions

View File

@ -0,0 +1,31 @@
# Manifest for the `pm-worker` crate.
# Package metadata below is inherited from the workspace manifest.
[package]
name = "pm-worker"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
# Single binary target built from src/main.rs.
[[bin]]
name = "pm-worker"
path = "src/main.rs"
[dependencies]
# Sibling crates in this workspace.
pm-core = { path = "../pm-core" }
pm-agent-client = { path = "../pm-agent-client" }
# Dependencies whose versions are pinned by the workspace manifest.
tokio = { workspace = true, features = ["full"] }
sqlx = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
futures = { workspace = true }
rustls = { workspace = true }
# Dependencies versioned locally in this crate rather than by the workspace.
tokio-rustls = { version = "0.26" }
rustls-pemfile = { version = "2" }
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
# lettre is built without default features; only the rustls/tokio SMTP pieces are enabled.
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
reqwest = { workspace = true }

View File

@ -0,0 +1,45 @@
//! Helper for loading mTLS certificate/key material from disk.
//!
//! Reads PEM files referenced in [`SecurityConfig`] and returns the raw bytes
//! needed by [`pm_agent_client::AgentClient`].
use pm_core::config::SecurityConfig;
/// Raw PEM bytes for mTLS client authentication and CA verification.
///
/// Each field holds the unparsed contents of one file exactly as read from
/// disk by [`load_agent_certs`]; nothing is validated or decoded here.
pub struct AgentCerts {
/// Contents of the file at `SecurityConfig::agent_client_cert_path`.
pub client_cert: Vec<u8>,
/// Contents of the file at `SecurityConfig::agent_client_key_path`.
pub client_key: Vec<u8>,
/// Contents of the file at `SecurityConfig::ca_cert_path`.
pub ca_cert: Vec<u8>,
}
/// Read the three files the agent mTLS client needs, as named in [`SecurityConfig`].
///
/// Files are read in a fixed order: client certificate, client private key,
/// CA certificate. The bytes are returned untouched in an [`AgentCerts`].
///
/// # Errors
///
/// Fails on the first file that cannot be read; the error message names the
/// kind of file, its configured path and the underlying I/O error. Callers
/// should log it and skip the current poll cycle rather than crash.
pub fn load_agent_certs(security: &SecurityConfig) -> anyhow::Result<AgentCerts> {
    let client_cert = match std::fs::read(&security.agent_client_cert_path) {
        Ok(bytes) => bytes,
        Err(e) => anyhow::bail!(
            "Failed to read agent client cert '{}': {}",
            security.agent_client_cert_path,
            e
        ),
    };

    let client_key = match std::fs::read(&security.agent_client_key_path) {
        Ok(bytes) => bytes,
        Err(e) => anyhow::bail!(
            "Failed to read agent client key '{}': {}",
            security.agent_client_key_path,
            e
        ),
    };

    let ca_cert = match std::fs::read(&security.ca_cert_path) {
        Ok(bytes) => bytes,
        Err(e) => anyhow::bail!("Failed to read CA cert '{}': {}", security.ca_cert_path, e),
    };

    Ok(AgentCerts {
        client_cert,
        client_key,
        ca_cert,
    })
}

View File

@ -0,0 +1,86 @@
//! Periodic audit log integrity verification.
//!
//! Runs every 24 hours, walks the audit_log rows ordered by id,
//! verifies each row_hash matches the recomputed hash, and logs the
//! result as an `AuditIntegrityVerified` event. If tampering is
//! detected, logs an error and creates an alert.
use std::sync::Arc;
use std::time::Duration;
use sqlx::PgPool;
use pm_core::audit::{log_event, verify_integrity, AuditAction};
use pm_core::config::AppConfig;
/// Run the audit integrity verifier: once at startup, then every 24 hours.
///
/// This function never returns; it is meant to be spawned as a background
/// task. `_config` is accepted for signature parity with the other worker
/// tasks and is currently unused.
pub async fn run_audit_verifier(pool: PgPool, _config: Arc<AppConfig>) {
    /// Time between scheduled verification passes.
    const VERIFY_PERIOD: Duration = Duration::from_secs(24 * 60 * 60);

    tracing::info!("Audit integrity verifier started");

    // Run immediately on startup.
    verify_once(&pool).await;

    // `tokio::time::interval` completes its first tick immediately, which would
    // trigger a second full verification right after the startup pass. Start
    // the schedule one full period from now instead.
    let mut interval =
        tokio::time::interval_at(tokio::time::Instant::now() + VERIFY_PERIOD, VERIFY_PERIOD);

    loop {
        interval.tick().await;
        tracing::info!("Running scheduled audit integrity verification");
        verify_once(&pool).await;
    }
}
/// Run a single integrity verification pass.
///
/// Steps, in order:
/// 1. Call [`verify_integrity`] and log the outcome (every reported error is
///    logged individually when the chain is not intact).
/// 2. Record an `AuditIntegrityVerified` audit event summarising the pass;
///    at most the first 10 errors are embedded in the event payload.
/// 3. Stamp `system_config.audit_integrity_last_verified` with the current
///    database time. A failure here is logged but does not abort anything.
async fn verify_once(pool: &PgPool) {
    let result = verify_integrity(pool).await;

    if result.intact {
        tracing::info!(
            rows_checked = result.rows_checked,
            "Audit integrity verification passed"
        );
    } else {
        tracing::error!(
            rows_checked = result.rows_checked,
            error_count = result.errors.len(),
            "Audit integrity verification FAILED — tampering detected!"
        );
        for err in &result.errors {
            tracing::error!(
                row_id = err.row_id,
                expected_hash = %err.expected_hash,
                actual_hash = %err.actual_hash,
                "Audit chain integrity error"
            );
        }
    }

    // Log the verification event.
    log_event(
        pool,
        AuditAction::AuditIntegrityVerified,
        None,
        None,
        Some("audit_log"),
        None,
        serde_json::json!({
            "intact": result.intact,
            "rows_checked": result.rows_checked,
            "error_count": result.errors.len(),
            // Cap the embedded error list so one event cannot grow without bound.
            "errors": result.errors.iter().take(10).map(|e| serde_json::json!({
                "row_id": e.row_id,
                "expected_hash": e.expected_hash,
                "actual_hash": e.actual_hash,
            })).collect::<Vec<_>>(),
        }),
        None,
        None,
    )
    .await;

    // Update the last-verified timestamp. This is best-effort, but a failure
    // must be visible: otherwise the stored timestamp silently goes stale.
    match sqlx::query(
        "UPDATE system_config SET value = NOW()::text, updated_at = NOW() WHERE key = 'audit_integrity_last_verified'",
    )
    .execute(pool)
    .await
    {
        Ok(done) if done.rows_affected() == 0 => {
            tracing::warn!(
                "Audit verifier: no 'audit_integrity_last_verified' row in system_config; timestamp not recorded"
            );
        },
        Ok(_) => {},
        Err(e) => {
            tracing::warn!(
                error = %e,
                "Audit verifier: failed to update audit_integrity_last_verified"
            );
        },
    }
}

331
crates/pm-worker/src/email.rs Executable file
View File

@ -0,0 +1,331 @@
//! Email notification module.
//!
//! Loads SMTP configuration from `system_config` and sends notification emails
//! for patch job events (completion, failure) and maintenance window reminders.
//! All emails are optional and disabled by default via `notification_email_enabled`.
use lettre::{
message::{header::ContentType, Mailbox},
transport::smtp::authentication::Credentials,
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
};
use sqlx::PgPool;
use pm_core::audit::{log_event, AuditAction};
/// SMTP configuration loaded from `system_config`.
///
/// Populated by `load_smtp_settings`; every field maps to one `smtp_*` key.
/// A missing key yields an empty string (or the noted fallback).
struct SmtpSettings {
/// `smtp_enabled`; true only when the stored value is exactly `"true"`.
enabled: bool,
/// `smtp_host`: server name handed to the lettre transport builder.
host: String,
/// `smtp_port`; falls back to 587 when missing or unparsable.
port: u16,
/// `smtp_username`; when empty, no credentials are sent.
username: String,
/// `smtp_password`, used together with `username`.
password: String,
/// `smtp_from`: sender address used when no notification-specific one is set.
from: String,
/// `smtp_tls_mode`: `"tls"`, `"starttls"`, or anything else for plaintext.
tls_mode: String,
}
/// Notification preferences loaded from `system_config`.
///
/// Populated by `load_notification_settings` from the `notification_email_*` keys.
struct NotificationSettings {
/// `notification_email_enabled`; true only when the stored value is exactly `"true"`.
email_enabled: bool,
/// `notification_email_from`; when empty, `SmtpSettings::from` is used instead.
email_from: String,
/// `notification_email_recipients`, stored as a JSON array of address strings.
recipients: Vec<String>,
}
/// Load SMTP settings from the `system_config` table.
async fn load_smtp_settings(pool: &PgPool) -> SmtpSettings {
let rows: Vec<(String, String)> = sqlx::query_as(
"SELECT key, value FROM system_config WHERE key IN (
'smtp_enabled', 'smtp_host', 'smtp_port', 'smtp_username',
'smtp_password', 'smtp_from', 'smtp_tls_mode'
)",
)
.fetch_all(pool)
.await
.unwrap_or_default();
let get = |key: &str| -> String {
rows.iter()
.find(|(k, _)| k == key)
.map(|(_, v)| v.clone())
.unwrap_or_default()
};
SmtpSettings {
enabled: get("smtp_enabled") == "true",
host: get("smtp_host"),
port: get("smtp_port").parse().unwrap_or(587),
username: get("smtp_username"),
password: get("smtp_password"),
from: get("smtp_from"),
tls_mode: get("smtp_tls_mode"),
}
}
/// Load notification preferences from `system_config`.
async fn load_notification_settings(pool: &PgPool) -> NotificationSettings {
let rows: Vec<(String, String)> = sqlx::query_as(
"SELECT key, value FROM system_config WHERE key IN (
'notification_email_enabled', 'notification_email_from', 'notification_email_recipients'
)",
)
.fetch_all(pool)
.await
.unwrap_or_default();
let get = |key: &str| -> String {
rows.iter()
.find(|(k, _)| k == key)
.map(|(_, v)| v.clone())
.unwrap_or_default()
};
let recipients: Vec<String> =
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
NotificationSettings {
email_enabled: get("notification_email_enabled") == "true",
email_from: get("notification_email_from"),
recipients,
}
}
/// Build an async SMTP transport from settings.
fn build_transport(settings: &SmtpSettings) -> Result<AsyncSmtpTransport<Tokio1Executor>, String> {
match settings.tls_mode.as_str() {
"tls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(&settings.host)
.map_err(|e| format!("TLS relay error: {}", e))?;
builder = builder.port(settings.port);
if !settings.username.is_empty() {
builder = builder.credentials(Credentials::new(
settings.username.clone(),
settings.password.clone(),
));
}
Ok(builder.build())
},
"starttls" => {
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(&settings.host)
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
builder = builder.port(settings.port);
if !settings.username.is_empty() {
builder = builder.credentials(Credentials::new(
settings.username.clone(),
settings.password.clone(),
));
}
Ok(builder.build())
},
_ => {
// "none" — plaintext / no TLS
let mut builder =
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(&settings.host)
.port(settings.port);
if !settings.username.is_empty() {
builder = builder.credentials(Credentials::new(
settings.username.clone(),
settings.password.clone(),
));
}
Ok(builder.build())
},
}
}
/// Send a plain-text notification email to every configured recipient.
///
/// Returns `true` only when the message was handed to the SMTP server
/// successfully. Returns `false` (after logging the reason) when SMTP or
/// email notifications are disabled, no recipients are configured, the sender
/// address is invalid, the message or transport cannot be built, or sending
/// fails. Recipients whose address does not parse are logged and skipped.
async fn send_email(pool: &PgPool, subject: &str, body: &str) -> bool {
    let smtp = load_smtp_settings(pool).await;
    if !smtp.enabled {
        tracing::debug!("SMTP not enabled, skipping email notification");
        return false;
    }

    let notif = load_notification_settings(pool).await;
    if !notif.email_enabled {
        tracing::debug!("Email notifications disabled, skipping");
        return false;
    }
    if notif.recipients.is_empty() {
        tracing::debug!("No email recipients configured, skipping notification");
        return false;
    }

    // A notification-specific sender wins; otherwise use the SMTP default.
    let sender_text = if notif.email_from.is_empty() {
        smtp.from.clone()
    } else {
        notif.email_from
    };
    let sender = match sender_text.parse::<Mailbox>() {
        Ok(mailbox) => mailbox,
        Err(e) => {
            tracing::error!(error = %e, "Invalid from address for email notification");
            return false;
        },
    };

    let mut message = Message::builder()
        .from(sender)
        .subject(subject)
        .header(ContentType::TEXT_PLAIN);

    // Add all recipients, skipping any that fail to parse.
    for recipient in &notif.recipients {
        match recipient.parse::<Mailbox>() {
            Ok(mailbox) => {
                message = message.to(mailbox);
            },
            Err(e) => {
                tracing::error!(error = %e, recipient = %recipient, "Invalid recipient address");
            },
        }
    }

    let email = match message.body(body.to_string()) {
        Ok(built) => built,
        Err(e) => {
            tracing::error!(error = %e, "Failed to build email message");
            return false;
        },
    };

    let transport = match build_transport(&smtp) {
        Ok(ready) => ready,
        Err(e) => {
            tracing::error!(error = %e, "Failed to build SMTP transport");
            return false;
        },
    };

    if let Err(e) = transport.send(email).await {
        tracing::error!(error = %e, subject, "Failed to send email notification");
        return false;
    }

    tracing::info!(subject, "Email notification sent successfully");
    true
}
/// Notify the configured recipients that patching failed on one host.
///
/// Sends the email via `send_email`, then records an `EmailNotificationSent`
/// audit event against the patch job. The event is written whether or not
/// the email went out; its `sent` field carries the outcome.
pub async fn send_patch_failure_email(
    pool: &PgPool,
    host_fqdn: &str,
    job_id: &str,
    error_message: &str,
) {
    let subject = format!("[Patch Manager] Patch Failed on {host_fqdn}");

    // One entry per line of the message; the empty entry is the blank separator line.
    let body = [
        format!("Patch operation failed on host: {host_fqdn}"),
        format!("Job ID: {job_id}"),
        format!("Error: {error_message}"),
        String::new(),
        "Please review the job details in the Patch Manager dashboard.".to_string(),
    ]
    .join("\n");

    let sent = send_email(pool, &subject, &body).await;

    let details = serde_json::json!({
        "type": "patch_failure",
        "host_fqdn": host_fqdn,
        "sent": sent,
    });
    log_event(
        pool,
        AuditAction::EmailNotificationSent,
        None,
        None,
        Some("patch_job"),
        Some(job_id),
        details,
        None,
        None,
    )
    .await;
}
/// Notify the configured recipients that a patch job has finished.
///
/// The email lists the total, succeeded and failed host counts. Afterwards an
/// `EmailNotificationSent` audit event is recorded against the job, with the
/// same counts and a `sent` flag reflecting whether the email went out.
pub async fn send_job_completion_email(
    pool: &PgPool,
    job_id: &str,
    host_count: i64,
    succeeded_count: i64,
    failed_count: i64,
) {
    let subject = format!("[Patch Manager] Job {job_id} Completed");

    // One entry per line of the message; the empty entry is the blank separator line.
    let body = [
        format!("Patch job completed: {job_id}"),
        format!("Total hosts: {host_count}"),
        format!("Succeeded: {succeeded_count}"),
        format!("Failed: {failed_count}"),
        String::new(),
        "Please review the job details in the Patch Manager dashboard.".to_string(),
    ]
    .join("\n");

    let sent = send_email(pool, &subject, &body).await;

    let details = serde_json::json!({
        "type": "job_completion",
        "host_count": host_count,
        "succeeded_count": succeeded_count,
        "failed_count": failed_count,
        "sent": sent,
    });
    log_event(
        pool,
        AuditAction::EmailNotificationSent,
        None,
        None,
        Some("patch_job"),
        Some(job_id),
        details,
        None,
        None,
    )
    .await;
}
/// Remind the configured recipients of an upcoming maintenance window.
///
/// After attempting the email, a `MaintenanceWindowReminder` audit event is
/// recorded with the host, the window label and whether the email was sent.
#[allow(dead_code)]
pub async fn send_maintenance_window_reminder_email(
    pool: &PgPool,
    host_fqdn: &str,
    window_label: &str,
    start_at: &str,
) {
    let subject = format!("[Patch Manager] Upcoming Maintenance Window: {window_label}");

    // One entry per line of the message; the empty entry is the blank separator line.
    let body = [
        "Maintenance window reminder:".to_string(),
        format!("Host: {host_fqdn}"),
        format!("Window: {window_label}"),
        format!("Starts at: {start_at}"),
        String::new(),
        "Patch operations will begin at the scheduled time.".to_string(),
    ]
    .join("\n");

    let sent = send_email(pool, &subject, &body).await;

    let details = serde_json::json!({
        "type": "maintenance_reminder",
        "host_fqdn": host_fqdn,
        "window_label": window_label,
        "sent": sent,
    });
    log_event(
        pool,
        AuditAction::MaintenanceWindowReminder,
        None,
        None,
        Some("maintenance_window"),
        None,
        details,
        None,
        None,
    )
    .await;
}

View File

@ -0,0 +1,480 @@
//! Periodic health check poller for configured service and HTTP checks.
//!
//! Polls every `health_check_poll_interval_secs`, querying each enabled health
//! check definition and storing results in `host_health_check_results`.
//! Results older than 4 days are pruned on each cycle.
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use pm_core::{config::AppConfig, crypto};
use sqlx::{FromRow, PgPool};
use tokio::{sync::Semaphore, time};
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
use pm_agent_client::{AgentClient, AgentClientError};
// ─────────────────────────────────────────────────────────────────────────────
// DB row types
// ─────────────────────────────────────────────────────────────────────────────
/// Row fetched for each enabled health check, joined with host connection info.
///
/// Field names mirror the columns selected by the poller's query over
/// `host_health_checks`, so sqlx can map rows onto this struct by name.
#[derive(FromRow)]
// Not every selected column is read by the poller code.
#[allow(dead_code)]
struct HealthCheckRow {
/// Primary key of the check; results are stored against it.
id: Uuid,
/// Host that owns the check definition.
host_id: Uuid,
name: String,
/// Dispatch key: `"service"` or `"http"`; any other value is recorded as unhealthy.
check_type: String,
/// Service to query on the agent; required for `"service"` checks.
service_name: Option<String>,
/// URL to GET; required for `"http"` checks.
url: Option<String>,
/// When set, the HTTP response body must contain this substring.
expected_body: Option<String>,
/// Whether to accept invalid TLS certificates; NULL is treated as `true`.
ignore_cert_errors: Option<bool>,
/// When set, HTTP basic auth is sent and an encrypted password must be present.
basic_auth_user: Option<String>,
/// Encrypted basic-auth password, decrypted together with the nonce below.
basic_auth_pass_encrypted: Option<Vec<u8>>,
basic_auth_pass_nonce: Option<Vec<u8>>,
/// Optional alternative host whose address/port are used instead of the owner's.
target_host_id: Option<Uuid>,
/// Address of the target host when one is joined, otherwise of the owning host.
ip_address: String,
/// Agent port, chosen by the same target-then-owner rule as `ip_address`.
agent_port: i32,
}
// ─────────────────────────────────────────────────────────────────────────────
// Public entry point
// ─────────────────────────────────────────────────────────────────────────────
/// Run the health check poller loop indefinitely.
///
/// On each tick all enabled health checks are queried concurrently (up to
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
/// to `host_health_check_results` and stale rows are pruned.
///
/// A cycle is skipped (with an error log) when the agent certificates, the
/// crypto key or the list of checks cannot be loaded. Configuration values of
/// zero for the interval or the concurrency limit are clamped to 1, because a
/// zero interval would panic and a zero-permit semaphore would stall every
/// check forever.
pub async fn run_health_check_poller(pool: PgPool, config: Arc<AppConfig>) {
    // `tokio::time::interval` panics on a zero period.
    let interval_secs = config.worker.health_check_poll_interval_secs.max(1);
    // With zero permits no spawned check could ever acquire the semaphore.
    let max_in_flight = config.worker.max_concurrent_agent_calls.max(1);

    let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
    // If a cycle overruns the interval, resume the normal cadence afterwards
    // instead of firing the missed ticks back-to-back.
    ticker.set_missed_tick_behavior(time::MissedTickBehavior::Delay);

    tracing::info!(interval_secs, "Health check poller started");

    loop {
        ticker.tick().await;

        // Load certs on each cycle so cert rotation is picked up automatically.
        let certs = match load_agent_certs(&config.security) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "Health check poller: failed to load agent certs — skipping cycle"
                );
                continue;
            },
        };
        let client_cert = Arc::new(certs.client_cert);
        let client_key = Arc::new(certs.client_key);
        let ca_cert = Arc::new(certs.ca_cert);

        // Load the crypto key for decrypting HTTP check passwords.
        let crypto_key = match crypto::load_or_create_key(Path::new(crypto::KEY_PATH)) {
            Ok(k) => Arc::new(k),
            Err(e) => {
                tracing::error!(
                    error = %e,
                    "Health check poller: failed to load crypto key — skipping cycle"
                );
                continue;
            },
        };

        // Fetch all enabled health checks with host connection info.
        // When a target host is set, its address and port replace the owner's.
        let checks: Vec<HealthCheckRow> = match sqlx::query_as(
            r#"
            SELECT
                hc.id,
                hc.host_id,
                hc.name,
                hc.check_type,
                hc.service_name,
                hc.url,
                hc.expected_body,
                hc.ignore_cert_errors,
                hc.basic_auth_user,
                hc.basic_auth_pass_encrypted,
                hc.basic_auth_pass_nonce,
                hc.target_host_id,
                host(COALESCE(th.ip_address, h.ip_address))::text AS ip_address,
                COALESCE(th.agent_port, h.agent_port) AS agent_port
            FROM host_health_checks hc
            JOIN hosts h ON h.id = hc.host_id
            LEFT JOIN hosts th ON th.id = hc.target_host_id
            WHERE hc.enabled = TRUE
            ORDER BY hc.id
            "#,
        )
        .fetch_all(&pool)
        .await
        {
            Ok(rows) => rows,
            Err(e) => {
                tracing::error!(error = %e, "Health check poller: failed to fetch health checks");
                continue;
            },
        };

        if checks.is_empty() {
            tracing::debug!("Health check poller: no enabled health checks, skipping cycle");
            prune_old_results(&pool).await;
            continue;
        }

        let total = checks.len();
        let semaphore = Arc::new(Semaphore::new(max_in_flight));
        let mut handles = Vec::with_capacity(total);

        // Spawn one task per check; the semaphore bounds how many run at once.
        for check in checks {
            let pool = pool.clone();
            let sem = semaphore.clone();
            let cert = client_cert.clone();
            let key = client_key.clone();
            let ca = ca_cert.clone();
            let ckey = crypto_key.clone();
            let handle = tokio::spawn(async move {
                let _permit = sem.acquire().await.expect("semaphore closed");
                run_check(pool, check, &cert, &key, &ca, &ckey).await
            });
            handles.push(handle);
        }

        // Collect results and tally counts.
        let mut healthy_count = 0usize;
        let mut unhealthy_count = 0usize;
        let mut error_count = 0usize;
        for handle in handles {
            match handle.await {
                Ok(true) => healthy_count += 1,
                Ok(false) => unhealthy_count += 1,
                Err(e) => {
                    tracing::error!(error = %e, "Health check poller task panicked");
                    error_count += 1;
                },
            }
        }

        tracing::info!(
            total,
            healthy_count,
            unhealthy_count,
            error_count,
            "Health check poll cycle complete"
        );

        // Prune results older than 4 days.
        prune_old_results(&pool).await;
    }
}
// ─────────────────────────────────────────────────────────────────────────────
// Check dispatch
// ─────────────────────────────────────────────────────────────────────────────
/// Run a single health check and persist the result. Returns `true` if healthy.
///
/// Dispatches on `check.check_type`: `"service"` goes through the agent,
/// `"http"` performs a direct HTTP request, and any other value is recorded
/// as unhealthy. The wall-clock duration of the check is stored as
/// `latency_ms`. A failure to insert the result row is logged and does not
/// change the returned health value.
async fn run_check(
    pool: PgPool,
    check: HealthCheckRow,
    client_cert: &[u8],
    client_key: &[u8],
    ca_cert: &[u8],
    crypto_key: &[u8; 32],
) -> bool {
    let start = Instant::now();

    let (healthy, detail) = match check.check_type.as_str() {
        "service" => run_service_check(&check, client_cert, client_key, ca_cert).await,
        "http" => run_http_check(&check, crypto_key).await,
        other => {
            tracing::warn!(
                check_id = %check.id,
                check_type = other,
                "Unknown health check type — treating as unhealthy"
            );
            (false, format!("Unknown check type: {other}"))
        },
    };

    // `as_millis` yields a u128; saturate instead of wrapping with an `as` cast
    // so an unexpectedly long check can never be stored as a bogus latency.
    let latency_ms = i32::try_from(start.elapsed().as_millis()).unwrap_or(i32::MAX);

    // Persist the result.
    if let Err(e) = sqlx::query(
        r#"
        INSERT INTO host_health_check_results (check_id, healthy, detail, latency_ms)
        VALUES ($1, $2, $3, $4)
        "#,
    )
    .bind(check.id)
    .bind(healthy)
    .bind(&detail)
    .bind(latency_ms)
    .execute(&pool)
    .await
    {
        tracing::error!(
            check_id = %check.id,
            error = %e,
            "Health check poller: failed to insert result"
        );
    }

    healthy
}
// ─────────────────────────────────────────────────────────────────────────────
// Service check (via mTLS AgentClient)
// ─────────────────────────────────────────────────────────────────────────────
/// Execute a service check by calling the agent's `/api/v1/system/services/{name}` endpoint.
///
/// Returns `(healthy, detail)`, where `detail` is a human-readable summary
/// stored alongside the result. The check is reported unhealthy, without
/// contacting the agent, when `service_name` is missing or the stored agent
/// port does not fit a TCP port number. Every agent-side error (timeout,
/// connection failure, API error, anything else) is also reported unhealthy.
async fn run_service_check(
    check: &HealthCheckRow,
    client_cert: &[u8],
    client_key: &[u8],
    ca_cert: &[u8],
) -> (bool, String) {
    let service_name = match &check.service_name {
        Some(name) => name.clone(),
        None => {
            return (false, "Service check missing service_name".to_string());
        },
    };

    // The port column is an i32; an `as u16` cast would silently wrap a
    // negative or too-large value and connect to the wrong port.
    let agent_port = match u16::try_from(check.agent_port) {
        Ok(port) => port,
        Err(_) => {
            return (
                false,
                format!(
                    "Invalid agent port {} for host {}",
                    check.agent_port, check.ip_address
                ),
            );
        },
    };

    let client = match AgentClient::new(
        &check.ip_address,
        agent_port,
        client_cert,
        client_key,
        ca_cert,
    ) {
        Ok(c) => c,
        Err(e) => {
            return (false, format!("Failed to build AgentClient: {e}"));
        },
    };

    match client.service_status(&service_name).await {
        Ok(data) => {
            // The agent decides healthiness; only the wording of the detail differs.
            let detail = if data.healthy {
                format!(
                    "Service '{}' is {}/{} (enabled: {})",
                    data.name, data.active_state, data.sub_state, data.enabled_state
                )
            } else {
                format!(
                    "Service '{}' status: {}/{} (unhealthy, enabled: {})",
                    data.name, data.active_state, data.sub_state, data.enabled_state
                )
            };
            (data.healthy, detail)
        },
        Err(AgentClientError::Timeout) => (
            false,
            format!("Agent timed out querying service '{service_name}'"),
        ),
        Err(AgentClientError::Connect(_)) => (
            false,
            format!("Agent connection refused for service '{service_name}'"),
        ),
        Err(AgentClientError::ApiError { code, message }) => {
            // 404, 400, 500 etc. from the agent means the service is unhealthy.
            (false, format!("Agent error [{code}]: {message}"))
        },
        Err(e) => (
            false,
            format!("Agent error querying service '{service_name}': {e}"),
        ),
    }
}
// ─────────────────────────────────────────────────────────────────────────────
// HTTP check (via reqwest, no mTLS)
// ─────────────────────────────────────────────────────────────────────────────
/// Execute an HTTP check by making a GET request to the configured URL.
/// Supports optional basic auth (decrypted from DB) and substring body matching.
///
/// Returns `(healthy, detail)`. The check is healthy only when the request
/// succeeds with a 2xx status and, if `expected_body` is set, the response
/// body contains it. The request has a 10 second timeout and follows at most
/// 5 redirects. Invalid TLS certificates are accepted when
/// `ignore_cert_errors` is true or unset.
///
/// If the HTTP client cannot be built, the check is reported unhealthy. It
/// does not fall back to a default client, which would have no timeout and
/// would ignore the configured certificate policy.
async fn run_http_check(check: &HealthCheckRow, crypto_key: &[u8; 32]) -> (bool, String) {
    let url = match &check.url {
        Some(u) => u.clone(),
        None => {
            return (false, "HTTP check missing URL".to_string());
        },
    };

    // Build a reqwest client for this check.
    // Accept invalid certs if ignore_cert_errors is set (default true).
    let ignore_cert_errors = check.ignore_cert_errors.unwrap_or(true);
    let client = match reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .redirect(reqwest::redirect::Policy::limited(5))
        .danger_accept_invalid_certs(ignore_cert_errors)
        .build()
    {
        Ok(c) => c,
        Err(e) => {
            return (false, format!("HTTP check failed to build HTTP client: {e}"));
        },
    };

    // Build the request.
    let mut request = client.get(&url);

    // Add basic auth if configured.
    if let Some(user) = &check.basic_auth_user {
        // Both the ciphertext and its nonce are needed to recover the password.
        let password = match (
            &check.basic_auth_pass_encrypted,
            &check.basic_auth_pass_nonce,
        ) {
            (Some(enc), Some(nonce)) => match crypto::decrypt(enc, nonce, crypto_key) {
                Ok(p) => p,
                Err(e) => {
                    return (false, format!("Failed to decrypt basic auth password: {e}"));
                },
            },
            _ => {
                // No encrypted password stored — treat as missing credentials.
                return (
                    false,
                    "HTTP check has basic_auth_user but no encrypted password".to_string(),
                );
            },
        };
        request = request.basic_auth(user.as_str(), Some(password.as_str()));
    }

    // Execute the request.
    let response = match request.send().await {
        Ok(r) => r,
        Err(e) => {
            if e.is_timeout() {
                return (false, format!("HTTP check timed out: {url}"));
            } else if e.is_connect() {
                return (false, format!("HTTP check connection failed: {url}"));
            } else {
                return (false, format!("HTTP check request error: {e}"));
            }
        },
    };

    let status = response.status();

    // Check HTTP status code.
    if !status.is_success() {
        return (
            false,
            format!("HTTP check returned status {} for {url}", status.as_u16()),
        );
    }

    // Read the response body for substring matching. A body that cannot be
    // read makes the check unhealthy even when no substring is expected.
    let body = match response.text().await {
        Ok(b) => b,
        Err(e) => {
            return (
                false,
                format!("HTTP check failed to read response body: {e}"),
            );
        },
    };

    // Check expected_body substring match.
    if let Some(expected) = &check.expected_body {
        if !body.contains(expected) {
            return (
                false,
                format!("HTTP check body mismatch for {url}: expected substring not found"),
            );
        }
    }

    (
        true,
        format!("HTTP check OK for {url} (status {})", status.as_u16()),
    )
}
// ─────────────────────────────────────────────────────────────────────────────
// Prune old results
// ─────────────────────────────────────────────────────────────────────────────
/// Remove rows from `host_health_check_results` whose `checked_at` is more than 4 days old.
///
/// Best-effort: a database error is logged and otherwise ignored. The number
/// of deleted rows is logged only when at least one row was removed.
async fn prune_old_results(pool: &PgPool) {
    let outcome = sqlx::query(
        "DELETE FROM host_health_check_results WHERE checked_at < NOW() - INTERVAL '4 days'",
    )
    .execute(pool)
    .await;

    let deleted = match outcome {
        Ok(done) => done.rows_affected(),
        Err(e) => {
            tracing::error!(error = %e, "Health check poller: failed to prune old results");
            return;
        },
    };

    // Stay quiet when there was nothing to prune.
    if deleted > 0 {
        tracing::info!(
            rows_deleted = deleted,
            "Health check poller: pruned old results"
        );
    }
}
// ─────────────────────────────────────────────────────────────────────────────
// Health check gate for job executor
// ─────────────────────────────────────────────────────────────────────────────
/// Check whether all enabled health checks for a host are healthy.
///
/// Returns `Ok(true)` if all checks pass (or no checks are configured),
/// `Ok(false)` if any check is unhealthy or has no result yet.
///
/// Only the most recent result of each enabled check is considered.
///
/// # Errors
///
/// Returns the database error if the query fails.
pub async fn check_host_health_checks(pool: &PgPool, host_id: Uuid) -> anyhow::Result<bool> {
    // Count enabled checks whose latest result is missing or unhealthy.
    // A host with no enabled checks yields zero rows to count, so it is
    // reported healthy by the same query; a separate "are there any checks?"
    // round trip is unnecessary.
    let unhealthy_count: (i64,) = sqlx::query_as(
        r#"
        SELECT COUNT(*)
        FROM host_health_checks hc
        LEFT JOIN LATERAL (
            SELECT healthy
            FROM host_health_check_results r
            WHERE r.check_id = hc.id
            ORDER BY r.checked_at DESC
            LIMIT 1
        ) latest ON true
        WHERE hc.host_id = $1
          AND hc.enabled = TRUE
          AND (latest.healthy IS NULL OR latest.healthy = FALSE)
        "#,
    )
    .bind(host_id)
    .fetch_one(pool)
    .await?;

    Ok(unhealthy_count.0 == 0)
}

View File

@ -0,0 +1,249 @@
//! Periodic health poller for all registered hosts.
//!
//! Polls every host via the agent `/health` endpoint on each tick of
//! `health_poll_interval_secs`, with bounded concurrency controlled by a
//! [`tokio::sync::Semaphore`]. Also calls `/system/info` to refresh
//! `os_family`, `os_name`, `arch`, and `agent_version` in the hosts table.
use std::sync::Arc;
use pm_agent_client::{AgentClient, AgentClientError};
use pm_core::{config::AppConfig, models::HostHealthStatus};
use sqlx::{FromRow, PgPool};
use tokio::{sync::Semaphore, time};
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
/// Minimal host projection fetched for each poll cycle.
///
/// Holds only what is needed to reach a host's agent and to key the results.
#[derive(Debug, FromRow)]
struct HostRow {
/// Primary key of the host; health rows and status updates are keyed by it.
id: Uuid,
/// Host address as text, selected via `host(ip_address)::text`.
ip_address: String,
/// Agent port as stored in the `hosts` table.
agent_port: i32,
}
/// Run the health poller loop indefinitely.
///
/// On each tick all registered hosts are queried concurrently (up to
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
/// to `host_health_data` and the `hosts` table is updated.
///
/// A cycle is skipped (with an error log) when the agent certificates or the
/// host list cannot be loaded. Configuration values of zero for the interval
/// or the concurrency limit are clamped to 1, because a zero interval would
/// panic and a zero-permit semaphore would stall every poll forever.
pub async fn run_health_poller(pool: PgPool, config: Arc<AppConfig>) {
    // `tokio::time::interval` panics on a zero period.
    let interval_secs = config.worker.health_poll_interval_secs.max(1);
    // With zero permits no spawned poll could ever acquire the semaphore.
    let max_in_flight = config.worker.max_concurrent_agent_calls.max(1);

    let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
    // If a cycle overruns the interval, resume the normal cadence afterwards
    // instead of firing the missed ticks back-to-back.
    ticker.set_missed_tick_behavior(time::MissedTickBehavior::Delay);

    tracing::info!(interval_secs, "Health poller started");

    loop {
        ticker.tick().await;

        // Load certs on each cycle so cert rotation is picked up automatically.
        let certs = match load_agent_certs(&config.security) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!(error = %e, "Health poller: failed to load agent certs — skipping cycle");
                continue;
            },
        };
        let client_cert = Arc::new(certs.client_cert);
        let client_key = Arc::new(certs.client_key);
        let ca_cert = Arc::new(certs.ca_cert);

        // Fetch all hosts.
        let hosts: Vec<HostRow> = match sqlx::query_as(
            "SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
        )
        .fetch_all(&pool)
        .await
        {
            Ok(rows) => rows,
            Err(e) => {
                tracing::error!(error = %e, "Health poller: failed to fetch hosts");
                continue;
            },
        };

        if hosts.is_empty() {
            tracing::debug!("Health poller: no hosts registered, skipping cycle");
            continue;
        }

        let total = hosts.len();
        let semaphore = Arc::new(Semaphore::new(max_in_flight));
        let mut handles = Vec::with_capacity(total);

        // Spawn one task per host; the semaphore bounds how many run at once.
        for host in hosts {
            let pool = pool.clone();
            let sem = semaphore.clone();
            let cert = client_cert.clone();
            let key = client_key.clone();
            let ca = ca_cert.clone();
            let handle = tokio::spawn(async move {
                let _permit = sem.acquire().await.expect("semaphore closed");
                poll_host_health(pool, host, &cert, &key, &ca).await
            });
            handles.push(handle);
        }

        // Collect results and tally counts.
        let mut healthy = 0usize;
        let mut degraded = 0usize;
        let mut unreachable = 0usize;
        for handle in handles {
            match handle.await {
                Ok(HostHealthStatus::Healthy) => healthy += 1,
                Ok(HostHealthStatus::Degraded) => degraded += 1,
                Ok(HostHealthStatus::Unreachable) => unreachable += 1,
                // Any other status is not tallied.
                Ok(_) => {},
                Err(e) => tracing::error!(error = %e, "Health poller task panicked"),
            }
        }

        tracing::info!(
            total,
            healthy,
            degraded,
            unreachable,
            "Health poll cycle complete"
        );
    }
}
/// Poll a single host, persist the result, and return the determined status.
///
/// Also updates `agent_version` from the health response and
/// `os_family`/`os_name`/`arch` from the `/system/info` endpoint when available.
///
/// Status mapping:
/// - health call succeeds → `Healthy`, with the response stored as the payload;
/// - timeout, connection failure, an unusable stored port, or a client that
///   cannot be built → `Unreachable`;
/// - any other agent error → `Degraded`.
///
/// Whenever no health data was obtained, an empty JSON object is stored as
/// the payload. Database write failures are logged and do not affect the
/// returned status.
async fn poll_host_health(
    pool: PgPool,
    host: HostRow,
    client_cert: &[u8],
    client_key: &[u8],
    ca_cert: &[u8],
) -> HostHealthStatus {
    // Payload recorded whenever the agent returned no health data.
    let empty_payload = || serde_json::Value::Object(Default::default());

    // The port column is an i32; an `as u16` cast would silently wrap a
    // negative or too-large value and poll the wrong port.
    let client = match u16::try_from(host.agent_port) {
        Ok(port) => {
            AgentClient::new(&host.ip_address, port, client_cert, client_key, ca_cert)
                .map_err(|e| e.to_string())
        },
        Err(_) => Err(format!(
            "agent port {} is outside the valid range 0-65535",
            host.agent_port
        )),
    };

    // Determine status, payload, agent version, and optional system info.
    let (status, payload, agent_version, sys_info) = match client {
        Err(e) => {
            tracing::warn!(
                host_id = %host.id,
                error = %e,
                "Health poller: failed to build AgentClient"
            );
            (HostHealthStatus::Unreachable, empty_payload(), None, None)
        },
        Ok(client) => {
            let (status, payload, version) = match client.health().await {
                Ok(data) => {
                    let payload = serde_json::to_value(&data).unwrap_or_default();
                    (HostHealthStatus::Healthy, payload, Some(data.version))
                },
                Err(AgentClientError::Timeout) => {
                    tracing::warn!(host_id = %host.id, "Health poller: agent timed out");
                    (HostHealthStatus::Unreachable, empty_payload(), None)
                },
                Err(AgentClientError::Connect(_)) => {
                    tracing::warn!(host_id = %host.id, "Health poller: agent connection refused");
                    (HostHealthStatus::Unreachable, empty_payload(), None)
                },
                Err(e) => {
                    tracing::warn!(host_id = %host.id, error = %e, "Health poller: agent error");
                    (HostHealthStatus::Degraded, empty_payload(), None)
                },
            };

            // Try to fetch system info for OS/arch details (best-effort).
            // Skipped for unreachable hosts, where the call could not succeed.
            let sys_info = if status != HostHealthStatus::Unreachable {
                match client.system_info().await {
                    Ok(info) => Some(info),
                    Err(e) => {
                        tracing::debug!(
                            host_id = %host.id,
                            error = %e,
                            "Health poller: failed to get system info (non-fatal)"
                        );
                        None
                    },
                }
            } else {
                None
            };

            (status, payload, version, sys_info)
        },
    };

    // Insert into host_health_data.
    if let Err(e) = sqlx::query(
        r#"
        INSERT INTO host_health_data (host_id, status, payload)
        VALUES ($1, $2, $3)
        "#,
    )
    .bind(host.id)
    .bind(&status)
    .bind(&payload)
    .execute(&pool)
    .await
    {
        tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to insert health data");
    }

    // Build OS name from system info components (e.g. "Ubuntu 24.04").
    let os_name_from_sysinfo = sys_info
        .as_ref()
        .map(|i| format!("{} {}", i.os, i.os_version));

    // Update hosts table with health status, agent version, and OS details.
    // COALESCE preserves existing values when new data is unavailable.
    if let Err(e) = sqlx::query(
        r#"
        UPDATE hosts
        SET health_status = $2, last_health_at = NOW(),
            agent_version = COALESCE($3, agent_version),
            os_family = COALESCE($4, os_family),
            os_name = COALESCE($5, os_name),
            arch = COALESCE($6, arch)
        WHERE id = $1
        "#,
    )
    .bind(host.id)
    .bind(&status)
    .bind(&agent_version)
    .bind(sys_info.as_ref().map(|i| i.os.as_str()))
    .bind(os_name_from_sysinfo)
    .bind(sys_info.as_ref().map(|i| i.architecture.as_str()))
    .execute(&pool)
    .await
    {
        tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to update host status");
    }

    status
}

File diff suppressed because it is too large Load Diff

208
crates/pm-worker/src/main.rs Executable file
View File

@ -0,0 +1,208 @@
//! pm-worker — Linux Patch Manager background worker.
//!
//! Handles scheduled polling, job execution, maintenance window scheduling,
//! retry logic, email notifications, audit integrity verification, and data pruning.
mod agent_loader;
mod audit_verifier;
mod email;
mod health_check_poller;
mod health_poller;
mod job_executor;
mod maintenance_scheduler;
mod patch_poller;
mod refresh_listener;
mod ws_relay;
use chrono::Utc;
use pm_core::{config::AppConfig, db, logging};
use sqlx::PgPool;
use std::{sync::Arc, time::Duration};
use tokio::time;
use audit_verifier::run_audit_verifier;
use health_check_poller::run_health_check_poller;
use health_poller::run_health_poller;
use job_executor::run_job_executor;
use maintenance_scheduler::run_maintenance_scheduler;
use patch_poller::run_patch_poller;
use refresh_listener::run_refresh_listener;
use ws_relay::run_ws_relay;
/// Minimum number of applied migrations the worker requires before
/// accepting work. Prevents the worker from running against a schema
/// that hasn't been migrated yet.
const REQUIRED_MIGRATION_COUNT: i64 = 16;
/// Total time to keep re-checking the schema version before giving up.
/// This is the overall deadline used by `wait_for_schema`, not the pause
/// between individual checks (that pause is a fixed 5 seconds).
const SCHEMA_CHECK_TIMEOUT: Duration = Duration::from_secs(120);
/// Worker entry point.
///
/// Installs the rustls crypto provider, loads configuration (falling back to
/// `AppConfig::default()` when the file cannot be loaded), initialises logging
/// and the database pool, waits for the schema to be migrated, then spawns
/// every background task and waits on all of them.
///
/// # Errors
/// Returns an error when the database pool cannot be initialised or when
/// `wait_for_schema` gives up.
///
/// # Panics
/// Panics if the rustls crypto provider cannot be installed.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Install the default crypto provider for rustls (required since 0.23)
    rustls::crypto::ring::default_provider()
        .install_default()
        .expect("Failed to install rustls crypto provider");

    // Load configuration
    let config_path = std::env::var("PATCH_MANAGER_CONFIG")
        .unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
    let config = AppConfig::load(&config_path).unwrap_or_else(|e| {
        // Logging is not initialised yet, so report on stderr. Include the
        // path and the underlying error so the operator can tell a missing
        // file from a malformed one instead of silently running on defaults.
        eprintln!("Config file '{config_path}' not found or invalid ({e}), using defaults");
        AppConfig::default()
    });

    // Initialize logging
    logging::init(&config.logging);
    tracing::info!(
        version = env!("CARGO_PKG_VERSION"),
        "patch-manager-worker starting"
    );

    // Initialize database pool
    let pool = db::init_pool(&config.database).await?;

    // Wait for schema to be at the expected version (web process runs migrations)
    wait_for_schema(&pool).await?;

    let config = Arc::new(config);

    // Spawn worker tasks
    let heartbeat_handle = tokio::spawn(run_heartbeat(
        pool.clone(),
        config.worker.heartbeat_interval_secs,
    ));
    // M4: agent health poller, patch data poller, on-demand refresh listener
    let health_handle = tokio::spawn(run_health_poller(pool.clone(), config.clone()));
    let patch_handle = tokio::spawn(run_patch_poller(pool.clone(), config.clone()));
    let refresh_handle = tokio::spawn(run_refresh_listener(pool.clone(), config.clone()));
    // M5: job execution engine
    let job_exec_handle = tokio::spawn(run_job_executor(pool.clone(), config.clone()));
    // M6: maintenance window scheduler
    let maint_sched_handle = tokio::spawn(run_maintenance_scheduler(pool.clone(), config.clone()));
    // M7: WS relay — streams agent job events → DB → pg_notify → browser WS
    let ws_relay_handle = tokio::spawn(run_ws_relay(pool.clone(), config.clone()));
    // M11: audit integrity verification (runs every 24 hours)
    let audit_verifier_handle = tokio::spawn(run_audit_verifier(pool.clone(), config.clone()));
    // Health check poller — runs configured service/HTTP health checks
    let health_check_handle = tokio::spawn(run_health_check_poller(pool.clone(), config.clone()));
    // Enrollment cleanup task (runs every hour)
    let enrollment_cleanup_handle = tokio::spawn(run_enrollment_cleanup_task(pool.clone()));

    tracing::info!("Worker tasks started");

    // Wait for all tasks (they run indefinitely)
    let _ = tokio::join!(
        heartbeat_handle,
        health_handle,
        patch_handle,
        refresh_handle,
        job_exec_handle,
        maint_sched_handle,
        ws_relay_handle,
        audit_verifier_handle,
        health_check_handle,
        enrollment_cleanup_handle,
    );
    Ok(())
}
/// Block until at least `REQUIRED_MIGRATION_COUNT` migrations are reported by
/// `db::check_schema_version`.
///
/// The check is repeated with a 5-second pause between attempts. Both a
/// too-low count and a failed check are logged as warnings and retried.
///
/// # Errors
/// Fails once `SCHEMA_CHECK_TIMEOUT` has elapsed without a passing check.
async fn wait_for_schema(pool: &PgPool) -> anyhow::Result<()> {
    let give_up_at = tokio::time::Instant::now() + SCHEMA_CHECK_TIMEOUT;
    loop {
        let outcome = db::check_schema_version(pool).await;
        match outcome {
            Err(e) => {
                tracing::warn!(error = %e, "Schema version check failed, retrying...");
            },
            Ok(applied) if applied < REQUIRED_MIGRATION_COUNT => {
                tracing::warn!(
                    migration_count = applied,
                    required = REQUIRED_MIGRATION_COUNT,
                    "Schema not ready, waiting..."
                );
            },
            Ok(applied) => {
                tracing::info!(migration_count = applied, "Schema version check passed");
                return Ok(());
            },
        }

        // The deadline is evaluated after each failed attempt, before sleeping.
        let timed_out = tokio::time::Instant::now() >= give_up_at;
        if timed_out {
            anyhow::bail!(
                "Schema not ready after {}s — is the web process running migrations?",
                SCHEMA_CHECK_TIMEOUT.as_secs()
            );
        }
        time::sleep(Duration::from_secs(5)).await;
    }
}
/// Writes a heartbeat row to `worker_heartbeat` every `interval_secs`.
/// The web process can query this to confirm the worker is alive.
///
/// The row with `id = 1` is upserted with the current time and this crate's
/// version. Write failures are logged and the loop keeps running.
///
/// An `interval_secs` of `0` is treated as `1`: `tokio::time::interval`
/// panics on a zero period, which would otherwise kill the heartbeat task.
async fn run_heartbeat(pool: PgPool, interval_secs: u64) {
    let period = Duration::from_secs(interval_secs.max(1));
    let mut ticker = time::interval(period);
    loop {
        ticker.tick().await;
        let result = sqlx::query(
            r#"
INSERT INTO worker_heartbeat (id, last_seen, worker_version)
VALUES (1, NOW(), $1)
ON CONFLICT (id) DO UPDATE
SET last_seen = EXCLUDED.last_seen,
worker_version = EXCLUDED.worker_version
"#,
        )
        .bind(env!("CARGO_PKG_VERSION"))
        .execute(&pool)
        .await;
        match result {
            Ok(_) => tracing::debug!("Worker heartbeat written"),
            Err(e) => tracing::error!(error = %e, "Worker heartbeat failed"),
        }
    }
}
/// Periodically deletes expired enrollment requests.
async fn run_enrollment_cleanup_task(pool: PgPool) {
let mut interval = tokio::time::interval(Duration::from_secs(3600)); // Every hour
interval.tick().await; // Initial tick to run immediately if needed
loop {
interval.tick().await;
let now = Utc::now();
match sqlx::query("DELETE FROM enrollment_requests WHERE expires_at < $1")
.bind(now)
.execute(&pool)
.await
{
Ok(result) => {
if result.rows_affected() > 0 {
tracing::info!(
removed = result.rows_affected(),
"Purged expired enrollment requests"
);
}
},
Err(e) => tracing::error!(error = %e, "Failed to purge expired enrollment requests"),
}
}
}

View File

@ -0,0 +1,381 @@
//! Maintenance window scheduler.
//!
//! Polls every 60 seconds and performs two tasks:
//!
//! 1. **Auto-apply**: For each enabled maintenance window with `auto_apply = true`
//! that is currently open, if the host has pending patches and no existing
//! patch_apply job queued/running for that window, automatically creates one.
//!
//! 2. **Dispatch**: For each open window, dispatch any queued non-immediate
//! patch jobs associated with the window's host.
//!
//! A window is considered "open" when:
//! - `once` — `start_at <= NOW() < start_at + duration_minutes * '1 minute'`
//! - `daily` — current UTC time-of-day is within the window's daily slot
//! - `weekly` — same as daily, but only on the matching `recurrence_day` (0=Sun)
//! - `monthly` — same as daily, but only on the matching `recurrence_day` (1-31)
use std::sync::Arc;
use pm_core::config::AppConfig;
use sqlx::{FromRow, PgPool};
use tokio::time;
use uuid::Uuid;
use crate::job_executor::process_job;
// ─────────────────────────────────────────────────────────────────────────────
// Internal types
// ─────────────────────────────────────────────────────────────────────────────
/// Row produced by the open-windows query in `dispatch_open_window_jobs`.
#[derive(Debug, FromRow)]
struct OpenWindowHost {
/// Host that has at least one enabled maintenance window open right now.
host_id: Uuid,
}
/// Row produced by the queued-jobs query in `dispatch_open_window_jobs`.
#[derive(Debug, FromRow)]
struct QueuedJobId {
/// Id of a non-immediate job that has a queued entry for the host.
job_id: Uuid,
}
/// Row produced by the open auto-apply windows query in
/// `auto_create_patch_jobs`.
#[derive(Debug, FromRow)]
struct AutoApplyWindow {
/// `maintenance_windows.id` of the open window.
window_id: Uuid,
/// Host the window belongs to.
host_id: Uuid,
}
/// Row from `host_patch_data` for a host with `patch_count > 0`.
#[derive(Debug, FromRow)]
// `host_id` is selected by the query but only `patch_count` is read (for
// logging) in `auto_create_patch_jobs`, hence the lint allowance.
#[allow(dead_code)]
struct PendingPatchHost {
/// Host the patch data belongs to.
host_id: Uuid,
/// Number of pending patches recorded for the host.
patch_count: i32,
}
/// Row returned by the job-creating insert in `auto_create_patch_jobs`
/// (`RETURNING job_id`).
#[derive(Debug, FromRow)]
struct InsertedJobId {
/// Id of the newly created patch job.
job_id: Uuid,
}
// ─────────────────────────────────────────────────────────────────────────────
// Public entry point
// ─────────────────────────────────────────────────────────────────────────────
/// Drive the maintenance scheduler forever on a 60-second cadence.
///
/// Each cycle first auto-creates `patch_apply` jobs for open auto-apply
/// windows, then dispatches queued non-immediate jobs for hosts whose window
/// is open. Spawned from `pm-worker/src/main.rs`.
pub async fn run_maintenance_scheduler(pool: PgPool, config: Arc<AppConfig>) {
    tracing::info!("Maintenance scheduler started");

    let period = std::time::Duration::from_secs(60);
    let mut ticker = time::interval(period);
    // The first tick of an interval completes immediately; discard it so the
    // first real cycle runs one full period after start-up.
    ticker.tick().await;

    loop {
        ticker.tick().await;
        tracing::debug!("Maintenance scheduler: checking open windows");

        // Phase 1: create jobs for auto-apply windows.
        auto_create_patch_jobs(pool.clone(), Arc::clone(&config)).await;
        // Phase 2: dispatch queued non-immediate jobs for open windows.
        dispatch_open_window_jobs(pool.clone(), Arc::clone(&config)).await;
    }
}
// ─────────────────────────────────────────────────────────────────────────────
// Step 1: Auto-create patch_apply jobs
// ─────────────────────────────────────────────────────────────────────────────
/// For each enabled maintenance window that is currently open AND has
/// `auto_apply = true`, check if the host has pending patches and no
/// existing patch_apply job for this window cycle. If so, create one.
///
/// Error handling: a failure of the initial open-windows query is logged and
/// ends this run; a failure of any per-window query is logged and only that
/// window is skipped. `_config` is accepted but not used.
async fn auto_create_patch_jobs(pool: PgPool, _config: Arc<AppConfig>) {
// Find all open windows with auto_apply=true
// NOTE(review): the recurring branches compare UTC time-of-day against
// `start_at::time + duration`. In PostgreSQL `time + interval` presumably
// wraps at midnight, so a slot that crosses 00:00 UTC may never match —
// confirm against the intended window semantics.
let auto_windows: Vec<AutoApplyWindow> = match sqlx::query_as(
r#"
SELECT mw.id AS window_id, mw.host_id
FROM maintenance_windows mw
WHERE mw.enabled = TRUE
AND mw.auto_apply = TRUE
AND (
( mw.recurrence = 'once'
AND mw.start_at <= NOW()
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
)
OR
( mw.recurrence = 'daily'
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
( mw.recurrence = 'weekly'
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
( mw.recurrence = 'monthly'
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
)
"#,
)
.fetch_all(&pool)
.await
{
Ok(w) => w,
Err(e) => {
tracing::error!(error = %e, "auto_create_patch_jobs: open-windows query failed");
return;
}
};
// Nothing to do this cycle when no auto-apply window is open.
if auto_windows.is_empty() {
tracing::debug!("auto_create: no open auto-apply windows this cycle");
return;
}
tracing::info!(
auto_window_count = auto_windows.len(),
"auto_create: found open auto-apply windows"
);
for win in &auto_windows {
// Check if host has pending patches
let pending: Option<PendingPatchHost> = match sqlx::query_as(
r#"
SELECT host_id, patch_count
FROM host_patch_data
WHERE host_id = $1 AND patch_count > 0
"#,
)
.bind(win.host_id)
.fetch_optional(&pool)
.await
{
Ok(p) => p,
Err(e) => {
tracing::error!(
error = %e,
host_id = %win.host_id,
"auto_create: patch data query failed"
);
continue;
},
};
// No row means no patch data with a positive count for this host.
let Some(pending) = pending else {
tracing::debug!(
host_id = %win.host_id,
"auto_create: no pending patches, skipping"
);
continue;
};
// Check if there's already a queued/running patch_apply job for this host
// that was created during this window cycle (within the window's time range).
// We use a simpler check: any non-completed patch_apply job for this host
// that references this maintenance window, OR any non-immediate job without
// a window that was created since the window opened.
// NOTE(review): the second branch compares `created_at` with the window's
// stored `start_at` minus 5 minutes and does not filter on a missing
// window id; for recurring windows `start_at` may be the first occurrence
// rather than the current one — confirm this is the intended cut-off.
let existing_job: bool = match sqlx::query_scalar(
r#"
SELECT EXISTS(
SELECT 1 FROM patch_jobs pj
JOIN patch_job_hosts pjh ON pj.id = pjh.job_id
WHERE pjh.host_id = $1
AND pj.status IN ('queued', 'running', 'pending')
AND pj.kind = 'patch_apply'
AND (
pj.maintenance_window_id = $2
OR
(pj.immediate = FALSE AND pj.created_at >=
(SELECT start_at - INTERVAL '5 minutes' FROM maintenance_windows WHERE id = $2)
)
)
)
"#,
)
.bind(win.host_id)
.bind(win.window_id)
.fetch_one(&pool)
.await
{
Ok(b) => b,
Err(e) => {
tracing::error!(
error = %e,
host_id = %win.host_id,
"auto_create: existing job check failed"
);
continue;
}
};
if existing_job {
tracing::debug!(
host_id = %win.host_id,
window_id = %win.window_id,
"auto_create: existing job already queued/running, skipping"
);
continue;
}
// Create a new patch_apply job for this host, linked to the window.
// The job row and its single host row are inserted in one statement via a
// CTE; the job is created queued, non-immediate, with an empty selection.
let job: Option<InsertedJobId> = match sqlx::query_as(
r#"
WITH new_job AS (
INSERT INTO patch_jobs
(kind, status, maintenance_window_id, immediate, patch_selection, notes)
VALUES
('patch_apply', 'queued', $1, FALSE, '[]'::jsonb,
'Auto-created by maintenance window scheduler')
RETURNING id AS job_id
)
INSERT INTO patch_job_hosts (job_id, host_id, status)
SELECT new_job.job_id, $2, 'queued'
FROM new_job
RETURNING job_id
"#,
)
.bind(win.window_id)
.bind(win.host_id)
.fetch_optional(&pool)
.await
{
Ok(j) => j,
Err(e) => {
tracing::error!(
error = %e,
host_id = %win.host_id,
window_id = %win.window_id,
"auto_create: job insert failed"
);
continue;
},
};
// Only log when the insert actually returned a row.
if let Some(job) = job {
tracing::info!(
job_id = %job.job_id,
host_id = %win.host_id,
window_id = %win.window_id,
patch_count = pending.patch_count,
"auto_create: created patch_apply job for host in maintenance window"
);
}
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Step 2: Dispatch queued non-immediate jobs
// ─────────────────────────────────────────────────────────────────────────────
/// Find all hosts with a currently-open maintenance window, then for each,
/// find their queued non-immediate job entries and dispatch them.
async fn dispatch_open_window_jobs(pool: PgPool, config: Arc<AppConfig>) {
// ── 1. Find all host_ids with an open window right now ─────────────────
let open_hosts: Vec<OpenWindowHost> = match sqlx::query_as(
r#"
SELECT DISTINCT mw.host_id
FROM maintenance_windows mw
WHERE mw.enabled = TRUE
AND (
-- One-time: absolute window
( mw.recurrence = 'once'
AND mw.start_at <= NOW()
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
)
OR
-- Daily: time-of-day slot, any day
( mw.recurrence = 'daily'
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
-- Weekly: matching day-of-week + time-of-day slot
( mw.recurrence = 'weekly'
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
OR
-- Monthly: matching day-of-month + time-of-day slot
( mw.recurrence = 'monthly'
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
+ (mw.duration_minutes * INTERVAL '1 minute'))
)
)
"#,
)
.fetch_all(&pool)
.await
{
Ok(hosts) => hosts,
Err(e) => {
tracing::error!(error = %e, "dispatch_open_window_jobs: open-hosts query failed");
return;
}
};
if open_hosts.is_empty() {
tracing::debug!("Maintenance scheduler: no open windows this cycle");
return;
}
tracing::info!(
open_host_count = open_hosts.len(),
"Maintenance scheduler: found hosts with open windows"
);
// ── 2. For each open host, find distinct queued non-immediate job IDs ──
for host in open_hosts {
let job_ids: Vec<QueuedJobId> = match sqlx::query_as(
r#"
SELECT DISTINCT pjh.job_id
FROM patch_job_hosts pjh
JOIN patch_jobs j ON j.id = pjh.job_id
WHERE pjh.host_id = $1
AND pjh.status = 'queued'
AND j.immediate = FALSE
AND j.status != 'cancelled'
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
"#,
)
.bind(host.host_id)
.fetch_all(&pool)
.await
{
Ok(ids) => ids,
Err(e) => {
tracing::error!(
error = %e,
host_id = %host.host_id,
"dispatch_open_window_jobs: queued jobs query failed"
);
continue;
},
};
for job in job_ids {
tracing::info!(
job_id = %job.job_id,
host_id = %host.host_id,
"Maintenance scheduler: dispatching non-immediate job (window open)"
);
let (p, c) = (pool.clone(), config.clone());
let job_id = job.job_id;
tokio::spawn(async move {
process_job(p, c, job_id).await;
});
}
}
}

View File

@ -0,0 +1,202 @@
//! Periodic patch-data poller for all registered hosts.
//!
//! Polls every host via the agent `/patches` and `/packages` endpoints on
//! each tick of `patch_poll_interval_secs`, with bounded concurrency
//! controlled by a [`tokio::sync::Semaphore`].
use std::sync::Arc;
use pm_agent_client::AgentClient;
use pm_core::config::AppConfig;
use sqlx::{FromRow, PgPool};
use tokio::{sync::Semaphore, time};
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
/// Minimal host projection fetched for each poll cycle.
#[derive(Debug, FromRow)]
struct HostRow {
/// Primary key of the host.
id: Uuid,
/// Host address as text; the query selects `host(ip_address)::text`.
ip_address: String,
/// Agent port as stored in the database (an `i32`, converted when the
/// agent client is built).
agent_port: i32,
}
/// Run the patch poller loop indefinitely.
///
/// On each tick all registered hosts are queried concurrently (up to
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
/// to `host_patch_data` and `hosts.last_patch_at` is updated.
///
/// Agent certificates are reloaded at the start of every cycle; when they
/// cannot be loaded, or the host list cannot be fetched, the cycle is skipped.
///
/// Zero-valued settings are clamped to 1: `tokio::time::interval` panics on a
/// zero period, and a semaphore with zero permits would leave every per-host
/// task waiting forever, hanging the poller on its first cycle.
pub async fn run_patch_poller(pool: PgPool, config: Arc<AppConfig>) {
    let interval_secs = config.worker.patch_poll_interval_secs;
    let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs.max(1)));
    tracing::info!(interval_secs, "Patch poller started");

    loop {
        ticker.tick().await;

        let certs = match load_agent_certs(&config.security) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!(error = %e, "Patch poller: failed to load agent certs — skipping cycle");
                continue;
            },
        };
        // Shared by reference count so each per-host task gets a cheap handle.
        let client_cert = Arc::new(certs.client_cert);
        let client_key = Arc::new(certs.client_key);
        let ca_cert = Arc::new(certs.ca_cert);

        let hosts: Vec<HostRow> = match sqlx::query_as(
            "SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
        )
        .fetch_all(&pool)
        .await
        {
            Ok(rows) => rows,
            Err(e) => {
                tracing::error!(error = %e, "Patch poller: failed to fetch hosts");
                continue;
            },
        };
        if hosts.is_empty() {
            tracing::debug!("Patch poller: no hosts registered, skipping cycle");
            continue;
        }

        let total = hosts.len();
        let max_in_flight = config.worker.max_concurrent_agent_calls.max(1);
        let semaphore = Arc::new(Semaphore::new(max_in_flight));
        let mut handles = Vec::with_capacity(total);
        for host in hosts {
            let pool = pool.clone();
            let sem = semaphore.clone();
            let cert = client_cert.clone();
            let key = client_key.clone();
            let ca = ca_cert.clone();
            let handle = tokio::spawn(async move {
                // The permit is held for the whole poll of this host.
                let _permit = sem.acquire().await.expect("semaphore closed");
                poll_host_patches(pool, host, &cert, &key, &ca).await
            });
            handles.push(handle);
        }

        let mut succeeded = 0usize;
        let mut failed = 0usize;
        for handle in handles {
            match handle.await {
                Ok(true) => succeeded += 1,
                Ok(false) => failed += 1,
                Err(e) => {
                    tracing::error!(error = %e, "Patch poller task panicked");
                    failed += 1;
                },
            }
        }
        tracing::info!(total, succeeded, failed, "Patch poll cycle complete");
    }
}
/// Poll a single host for patch and package data and persist the result.
///
/// Returns `true` once the `host_patch_data` upsert has succeeded. Returns
/// `false` when the stored agent port does not fit in a `u16`, the agent
/// client cannot be built, either agent call fails, or the upsert fails.
/// A failure to update `hosts.last_patch_at` is logged but still counts as
/// success.
async fn poll_host_patches(
    pool: PgPool,
    host: HostRow,
    client_cert: &[u8],
    client_key: &[u8],
    ca_cert: &[u8],
) -> bool {
    // `agent_port` is an `i32` in the database. A plain `as u16` cast would
    // silently wrap negative or too-large values to an unrelated port, so
    // convert with a range check instead.
    let Ok(port) = u16::try_from(host.agent_port) else {
        tracing::warn!(
            host_id = %host.id,
            agent_port = host.agent_port,
            "Patch poller: agent port out of range"
        );
        return false;
    };

    let client = match AgentClient::new(&host.ip_address, port, client_cert, client_key, ca_cert) {
        Ok(c) => c,
        Err(e) => {
            tracing::warn!(host_id = %host.id, error = %e, "Patch poller: failed to build AgentClient");
            return false;
        },
    };

    // Fetch patches and packages concurrently.
    let (patches_result, packages_result) =
        tokio::join!(client.patches(), client.packages_upgradable());
    let patches_data = match patches_result {
        Ok(d) => d,
        Err(e) => {
            tracing::warn!(host_id = %host.id, error = %e, "Patch poller: patches() failed");
            return false;
        },
    };
    let packages_data = match packages_result {
        Ok(d) => d,
        Err(e) => {
            tracing::warn!(host_id = %host.id, error = %e, "Patch poller: packages_upgradable() failed");
            return false;
        },
    };

    let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
    let installed_packages = serde_json::to_value(&packages_data.packages).unwrap_or_default();
    let patch_count = patches_data.total as i32;
    // Counts patches that carry at least one CVE id, not distinct CVEs.
    let cve_count = patches_data
        .patches
        .iter()
        .filter(|p| !p.cve_ids.is_empty())
        .count() as i32;

    // Upsert into host_patch_data (one row per host, latest poll wins).
    if let Err(e) = sqlx::query(
        r#"
INSERT INTO host_patch_data
(host_id, available_patches, installed_packages, patch_count, cve_count)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (host_id) DO UPDATE SET
available_patches = EXCLUDED.available_patches,
installed_packages = EXCLUDED.installed_packages,
patch_count = EXCLUDED.patch_count,
cve_count = EXCLUDED.cve_count,
polled_at = NOW()
"#,
    )
    .bind(host.id)
    .bind(&available_patches)
    .bind(&installed_packages)
    .bind(patch_count)
    .bind(cve_count)
    .execute(&pool)
    .await
    {
        tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to insert patch data");
        return false;
    }

    // Update hosts.last_patch_at.
    if let Err(e) = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
        .bind(host.id)
        .execute(&pool)
        .await
    {
        tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to update last_patch_at");
    }

    tracing::debug!(
        host_id = %host.id,
        patch_count,
        cve_count,
        "Patch data collected"
    );
    true
}

View File

@ -0,0 +1,269 @@
//! On-demand refresh listener.
//!
//! Listens on the PostgreSQL `refresh_requested` NOTIFY channel. When a
//! notification arrives the payload is expected to be a host UUID string.
//! The listener immediately polls that host for health and patch data and
//! persists the results — bypassing the normal poll intervals.
use std::sync::Arc;
use pm_agent_client::{AgentClient, AgentClientError};
use pm_core::{config::AppConfig, models::HostHealthStatus};
use sqlx::{FromRow, PgPool};
use tokio::time;
use uuid::Uuid;
use crate::agent_loader::load_agent_certs;
/// Minimal host row used for on-demand refresh.
#[derive(Debug, FromRow)]
struct HostRow {
/// Primary key of the host.
id: Uuid,
/// Host address as text; the query selects `host(ip_address)::text`.
ip_address: String,
/// Agent port as stored in the database (an `i32`, converted when the
/// agent client is built).
agent_port: i32,
}
/// Supervise the LISTEN/NOTIFY refresh listener forever.
///
/// Whenever the inner `listen_loop` returns an error, the error is logged and
/// the listener is restarted after a 5-second pause.
pub async fn run_refresh_listener(pool: PgPool, config: Arc<AppConfig>) {
    tracing::info!("Refresh listener started — listening on 'refresh_requested'");
    loop {
        // A clean return restarts the listener right away; only errors back off.
        let Err(e) = listen_loop(&pool, &config).await else {
            continue;
        };
        tracing::error!(
            error = %e,
            "Refresh listener disconnected, reconnecting in 5s"
        );
        let backoff = std::time::Duration::from_secs(5);
        time::sleep(backoff).await;
    }
}
/// Inner loop — returns `Err` only on a fatal listener error so the outer
/// loop can reconnect.
///
/// Opens a dedicated `PgListener` on `config.database.url`, subscribes to
/// `refresh_requested`, and for every notification: parses the payload as a
/// host UUID, loads the host row and the agent certificates, then spawns
/// `refresh_host` so the listener is never blocked by agent I/O.
///
/// Per-notification problems (bad UUID, host lookup failure, unknown host,
/// certificate load failure) are logged and the notification is dropped.
///
/// # Errors
/// Returns an error when connecting, subscribing, or receiving fails.
async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> {
    let mut listener = sqlx::postgres::PgListener::connect(&config.database.url).await?;
    listener.listen("refresh_requested").await?;
    tracing::debug!("Refresh listener connected and listening");

    loop {
        let notification = listener.recv().await?;
        let payload = notification.payload().to_string();
        tracing::info!(payload, "Refresh notification received");

        let host_id = match payload.parse::<Uuid>() {
            Ok(id) => id,
            Err(e) => {
                tracing::warn!(
                    payload,
                    error = %e,
                    "Refresh listener: invalid UUID in notification payload"
                );
                continue;
            },
        };

        // Fetch the host from the database. A query failure is reported as
        // such rather than being folded into "host not found", which would
        // hide database problems behind a misleading message.
        let lookup = sqlx::query_as::<_, HostRow>(
            "SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts WHERE id = $1",
        )
        .bind(host_id)
        .fetch_optional(pool)
        .await;
        let host = match lookup {
            Ok(Some(h)) => h,
            Ok(None) => {
                tracing::warn!(%host_id, "Refresh listener: host not found");
                continue;
            },
            Err(e) => {
                tracing::error!(
                    %host_id,
                    error = %e,
                    "Refresh listener: failed to fetch host"
                );
                continue;
            },
        };

        // Load certs for this refresh.
        let certs = match load_agent_certs(&config.security) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!(
                    %host_id,
                    error = %e,
                    "Refresh listener: failed to load agent certs"
                );
                continue;
            },
        };

        // Spawn the actual work so the listener loop is not blocked.
        let pool_clone = pool.clone();
        let cert = certs.client_cert;
        let key = certs.client_key;
        let ca = certs.ca_cert;
        tokio::spawn(async move {
            refresh_host(pool_clone, host, &cert, &key, &ca).await;
        });
    }
}
/// Perform a full health + patch refresh for one host and persist results.
///
/// Steps:
/// 1. Build an agent client. If the stored port does not fit in a `u16` or
///    the client cannot be built, the host is recorded as unreachable and the
///    refresh stops.
/// 2. Query agent health and persist it: `Healthy` on success, `Unreachable`
///    on timeout/connect errors, `Degraded` on any other error.
/// 3. Query patches and upgradable packages concurrently and, when both
///    succeed, upsert `host_patch_data` and bump `hosts.last_patch_at`.
///
/// All failures are logged; nothing is returned to the caller.
async fn refresh_host(
    pool: PgPool,
    host: HostRow,
    client_cert: &[u8],
    client_key: &[u8],
    ca_cert: &[u8],
) {
    // `agent_port` is an `i32` in the database. A plain `as u16` cast would
    // silently wrap out-of-range values to an unrelated port, so convert with
    // a range check and treat an invalid port like an unbuildable client.
    let Ok(port) = u16::try_from(host.agent_port) else {
        tracing::warn!(
            host_id = %host.id,
            agent_port = host.agent_port,
            "Refresh: agent port out of range"
        );
        persist_health_unreachable(&pool, host.id).await;
        return;
    };

    let client = match AgentClient::new(&host.ip_address, port, client_cert, client_key, ca_cert) {
        Ok(c) => c,
        Err(e) => {
            tracing::warn!(
                host_id = %host.id,
                error = %e,
                "Refresh: failed to build AgentClient"
            );
            persist_health_unreachable(&pool, host.id).await;
            return;
        },
    };

    // ── Health ────────────────────────────────────────────────────────────
    let (health_status, health_payload) = match client.health().await {
        Ok(data) => {
            let payload = serde_json::to_value(&data).unwrap_or_default();
            (HostHealthStatus::Healthy, payload)
        },
        Err(AgentClientError::Timeout) | Err(AgentClientError::Connect(_)) => {
            tracing::warn!(host_id = %host.id, "Refresh: agent unreachable");
            (
                HostHealthStatus::Unreachable,
                serde_json::Value::Object(Default::default()),
            )
        },
        Err(e) => {
            tracing::warn!(host_id = %host.id, error = %e, "Refresh: health error");
            (
                HostHealthStatus::Degraded,
                serde_json::Value::Object(Default::default()),
            )
        },
    };
    persist_health(&pool, host.id, &health_status, &health_payload).await;

    // ── Patch data ────────────────────────────────────────────────────────
    let (patches_result, packages_result) =
        tokio::join!(client.patches(), client.packages_upgradable());
    match (patches_result, packages_result) {
        (Ok(patches_data), Ok(packages_data)) => {
            let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
            let installed_packages =
                serde_json::to_value(&packages_data.packages).unwrap_or_default();
            let patch_count = patches_data.total as i32;
            // Counts patches that carry at least one CVE id, not distinct CVEs.
            let cve_count = patches_data
                .patches
                .iter()
                .filter(|p| !p.cve_ids.is_empty())
                .count() as i32;

            if let Err(e) = sqlx::query(
                r#"
INSERT INTO host_patch_data
(host_id, available_patches, installed_packages, patch_count, cve_count)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (host_id) DO UPDATE SET
available_patches = EXCLUDED.available_patches,
installed_packages = EXCLUDED.installed_packages,
patch_count = EXCLUDED.patch_count,
cve_count = EXCLUDED.cve_count,
polled_at = NOW()
"#,
            )
            .bind(host.id)
            .bind(&available_patches)
            .bind(&installed_packages)
            .bind(patch_count)
            .bind(cve_count)
            .execute(&pool)
            .await
            {
                tracing::error!(
                    host_id = %host.id,
                    error = %e,
                    "Refresh: failed to insert patch data"
                );
            } else {
                // Log the timestamp update failure instead of dropping it;
                // the refresh itself is still reported as complete.
                if let Err(e) = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
                    .bind(host.id)
                    .execute(&pool)
                    .await
                {
                    tracing::error!(
                        host_id = %host.id,
                        error = %e,
                        "Refresh: failed to update last_patch_at"
                    );
                }
                tracing::info!(
                    host_id = %host.id,
                    patch_count,
                    cve_count,
                    "On-demand refresh complete"
                );
            }
        },
        (Err(e), _) | (_, Err(e)) => {
            tracing::warn!(
                host_id = %host.id,
                error = %e,
                "Refresh: failed to collect patch data"
            );
        },
    }
}
/// Record `host_id` as unreachable with an empty JSON object as payload.
async fn persist_health_unreachable(pool: &PgPool, host_id: Uuid) {
    let empty_payload = serde_json::Value::Object(Default::default());
    persist_health(pool, host_id, &HostHealthStatus::Unreachable, &empty_payload).await;
}
/// Store one health observation for a host.
///
/// Appends a row to `host_health_data`, then updates `hosts.health_status`
/// and `hosts.last_health_at`. The two statements are independent: the host
/// row is updated even when the insert fails, and each failure is only logged.
async fn persist_health(
    pool: &PgPool,
    host_id: Uuid,
    status: &HostHealthStatus,
    payload: &serde_json::Value,
) {
    let inserted = sqlx::query(
        r#"
INSERT INTO host_health_data (host_id, status, payload)
VALUES ($1, $2, $3)
"#,
    )
    .bind(host_id)
    .bind(status)
    .bind(payload)
    .execute(pool)
    .await;
    if let Err(e) = inserted {
        tracing::error!(
            %host_id,
            error = %e,
            "Refresh: failed to insert health data"
        );
    }

    let updated =
        sqlx::query("UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1")
            .bind(host_id)
            .bind(status)
            .execute(pool)
            .await;
    if let Err(e) = updated {
        tracing::error!(%host_id, error = %e, "Refresh: failed to update host health_status");
    }
}

718
crates/pm-worker/src/ws_relay.rs Executable file
View File

@ -0,0 +1,718 @@
//! WS relay — M7
//!
//! For every running `patch_job_hosts` row that has an `agent_job_id`, open a
//! WebSocket to the corresponding agent, stream job-status events, update the
//! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS
//! handler can forward the event to connected clients.
use std::{collections::HashSet, error::Error, sync::Arc, time::Duration};
use anyhow::Context;
use futures::StreamExt;
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
ClientConfig as TlsClientConfig, RootCertStore,
};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::protocol::Message, Connector};
use uuid::Uuid;
use pm_agent_client::client::AgentClient;
use pm_agent_client::client::DEFAULT_AGENT_PORT;
use pm_core::config::AppConfig;
// ── Types ─────────────────────────────────────────────────────────────────────
/// One running `patch_job_hosts` row joined with its host, as selected by
/// `query_running_jobs`.
#[derive(Debug, sqlx::FromRow)]
struct RunningHostJob {
/// Patch job the row belongs to.
job_id: Uuid,
/// Host the row belongs to; `(job_id, host_id)` is the relay tracking key.
host_id: Uuid,
/// Job id assigned by the agent; the query only returns rows where it is set.
agent_job_id: String,
/// The host's FQDN when present, otherwise its IP address as text.
host_address: String,
}
/// JSON event streamed by the agent over its WS endpoint.
#[derive(Debug, Deserialize)]
struct AgentWsEvent {
// Deserialized from the payload; the lint allowance indicates it is not
// read after parsing.
#[allow(dead_code)]
job_id: String,
/// Job status string reported by the agent.
status: String,
/// Optional output text carried by the event.
output: Option<String>,
/// Optional error text carried by the event.
error: Option<String>,
// Deserialized from the payload; the lint allowance indicates it is not
// read after parsing.
#[allow(dead_code)]
progress_percent: Option<u8>,
}
/// Payload broadcast via `pg_notify('job_update', …)`.
///
/// Serialized to JSON; carries either a per-host update or a job-level
/// update, distinguished by `event_type`.
#[derive(Debug, Serialize)]
struct NotifyPayload {
event_type: String, // "host" or "job"
job_id: String,
host_id: String,
status: String,
output: Option<String>,
error_message: Option<String>,
agent_job_id: String,
// Job-level fields (only present when `event_type` is "job")
succeeded_count: Option<i64>,
failed_count: Option<i64>,
host_count: Option<i64>,
}
// ── Cert PEM bytes for building AgentClient ────────────────────────────────────
/// Raw PEM bytes read from the security config cert paths.
/// Used to build an [`AgentClient`] for HTTP polling fallback.
#[derive(Clone)]
struct CertPems {
/// Contents of the file at `security.agent_client_cert_path`.
client_cert: Vec<u8>,
/// Contents of the file at `security.agent_client_key_path`.
client_key: Vec<u8>,
/// Contents of the file at `security.ca_cert_path`.
ca_cert: Vec<u8>,
}
/// Load the client certificate, client key and CA certificate files named in
/// the security config and return their raw bytes.
///
/// The same three files are read by [`build_tls_config`]; this function does
/// no PEM parsing.
///
/// # Errors
/// Fails on the first file that cannot be read, with the offending path in
/// the error context.
async fn read_cert_pems(config: &AppConfig) -> anyhow::Result<CertPems> {
    let security = &config.security;
    // Struct fields are evaluated in source order, so the files are read
    // cert → key → CA, stopping at the first failure.
    Ok(CertPems {
        client_cert: tokio::fs::read(&security.agent_client_cert_path)
            .await
            .with_context(|| {
                format!("read agent client cert '{}'", security.agent_client_cert_path)
            })?,
        client_key: tokio::fs::read(&security.agent_client_key_path)
            .await
            .with_context(|| {
                format!("read agent client key '{}'", security.agent_client_key_path)
            })?,
        ca_cert: tokio::fs::read(&security.ca_cert_path)
            .await
            .with_context(|| format!("read CA cert '{}'", security.ca_cert_path))?,
    })
}
// ── Entry point ───────────────────────────────────────────────────────────────
/// Long-running task: polls the DB for running host-jobs and spawns a per-pair
/// relay task for each one that isn't already being tracked.
pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
tracing::info!("WS relay task started");
let active: Arc<Mutex<HashSet<(Uuid, Uuid)>>> = Arc::new(Mutex::new(HashSet::new()));
let mut interval = tokio::time::interval(Duration::from_secs(10));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
let rows = match query_running_jobs(&pool).await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "ws_relay: DB poll failed");
continue;
},
};
for row in rows {
let key = (row.job_id, row.host_id);
// Skip pairs that already have an active relay.
if active.lock().await.contains(&key) {
continue;
}
// Build the rustls ClientConfig once per connection.
let tls_config = match build_tls_config(&config).await {
Ok(c) => Arc::new(c),
Err(e) => {
tracing::error!(error = %e, "ws_relay: TLS config error");
continue;
},
};
// Read raw cert PEM bytes for HTTP polling fallback.
let cert_pems = match read_cert_pems(&config).await {
Ok(p) => p,
Err(e) => {
tracing::error!(error = %e, "ws_relay: cert PEM read error");
continue;
},
};
let poll_interval = config.worker.ws_relay_poll_interval_secs;
active.lock().await.insert(key);
let pool_c = pool.clone();
let active_c = active.clone();
tokio::spawn(async move {
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
agent_job_id = %row.agent_job_id,
host = %row.host_address,
"WS relay: starting relay"
);
match relay_one_job(&pool_c, &row, tls_config, &cert_pems, poll_interval).await {
Ok(()) => tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: completed"
),
Err(e) => tracing::error!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: ended with error"
),
}
active_c.lock().await.remove(&key);
});
}
}
}
// ── DB helpers ────────────────────────────────────────────────────────────────
/// Fetches every `patch_job_hosts` row whose status is `running` and which
/// already has an `agent_job_id`, joined with its host so that the address
/// (the FQDN if set, otherwise the IP address as text) is available.
///
/// Returns the rows as [`RunningHostJob`] values; a database error is wrapped
/// with the context `query_running_jobs`.
async fn query_running_jobs(pool: &PgPool) -> anyhow::Result<Vec<RunningHostJob>> {
    const RUNNING_JOBS_SQL: &str = r#"
SELECT
pjh.job_id,
pjh.host_id,
pjh.agent_job_id,
COALESCE(h.fqdn, host(h.ip_address)::text) AS host_address
FROM patch_job_hosts pjh
JOIN hosts h ON h.id = pjh.host_id
WHERE pjh.status = 'running'::job_status
AND pjh.agent_job_id IS NOT NULL
"#;

    let jobs: Vec<RunningHostJob> = sqlx::query_as(RUNNING_JOBS_SQL)
        .fetch_all(pool)
        .await
        .context("query_running_jobs")?;
    Ok(jobs)
}
// ── TLS ───────────────────────────────────────────────────────────────────────
/// Builds the rustls client configuration used for the agent WebSocket
/// connection.
///
/// Reads the agent client certificate, client key and CA certificate from the
/// paths in `config.security`, parses them as PEM, and produces a config that
/// trusts only the given CA, presents the client certificate, and advertises
/// `http/1.1` via ALPN.
///
/// # Errors
///
/// Returns an error when a file cannot be read, when PEM parsing fails, when
/// the client certificate file or the CA file contains no certificate, when
/// the key file contains no private key, or when rustls rejects the
/// certificate/key pair.
async fn build_tls_config(config: &AppConfig) -> anyhow::Result<TlsClientConfig> {
    let sec = &config.security;
    let cert_pem = tokio::fs::read(&sec.agent_client_cert_path)
        .await
        .with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
    let key_pem = tokio::fs::read(&sec.agent_client_key_path)
        .await
        .with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
    let ca_pem = tokio::fs::read(&sec.ca_cert_path)
        .await
        .with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;

    // Parse client certificate chain.
    let client_certs: Vec<CertificateDer<'static>> = {
        let mut cur = std::io::Cursor::new(&cert_pem);
        rustls_pemfile::certs(&mut cur)
            .collect::<Result<Vec<_>, _>>()
            .context("parse client cert PEM")?
    };
    // A file without any CERTIFICATE block parses "successfully" into an empty
    // chain; reject it here instead of failing later during the handshake.
    anyhow::ensure!(
        !client_certs.is_empty(),
        "no certificate found in agent client cert '{}'",
        sec.agent_client_cert_path
    );

    // Parse client private key.
    let client_key: PrivateKeyDer<'static> = {
        let mut cur = std::io::Cursor::new(&key_pem);
        rustls_pemfile::private_key(&mut cur)
            .context("parse client key PEM")?
            .context("no private key in PEM")?
    };

    // Build root store from CA cert.
    let mut root_store = RootCertStore::empty();
    {
        let mut cur = std::io::Cursor::new(&ca_pem);
        for cert_result in rustls_pemfile::certs(&mut cur) {
            root_store
                .add(cert_result.context("read CA cert entry")?)
                .context("add CA cert to root store")?;
        }
    }
    // With an empty root store no server certificate can ever be verified.
    anyhow::ensure!(
        !root_store.is_empty(),
        "no certificate found in CA cert '{}'",
        sec.ca_cert_path
    );

    let mut tls_config = TlsClientConfig::builder()
        .with_root_certificates(root_store)
        .with_client_auth_cert(client_certs, client_key)
        .context("build TlsClientConfig")?;

    // WebSocket requires HTTP/1.1 — without ALPN the server may negotiate
    // h2 (HTTP/2), which breaks the WebSocket upgrade handshake
    // ("Key mismatch in Sec-WebSocket-Accept header").
    tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
    Ok(tls_config)
}
// ── Per-job relay ─────────────────────────────────────────────────────────────
/// Relays agent events for a single `(job, host)` pair.
///
/// Opens a `wss://` connection to the agent's `/api/v1/ws/jobs` endpoint using
/// `tls_config`, then reads frames until the stream ends, a read error occurs,
/// a close frame arrives, or an event reports a terminal status
/// (`succeeded`, `failed` or `cancelled`). Every non-empty text/binary frame
/// is parsed as an [`AgentWsEvent`] and handed to [`process_event`]; frames
/// that do not parse are logged and skipped.
///
/// If the WebSocket connection cannot be established, the error and its whole
/// source chain are logged and the function falls back to
/// [`relay_one_job_poll`].
///
/// # Errors
///
/// Returns an error only when the polling fallback returns one; the WebSocket
/// path itself always ends with `Ok(())`.
async fn relay_one_job(
    pool: &PgPool,
    row: &RunningHostJob,
    tls_config: Arc<TlsClientConfig>,
    cert_pems: &CertPems,
    poll_interval_secs: u64,
) -> anyhow::Result<()> {
    // An IPv6 literal must be wrapped in brackets inside a URL authority,
    // otherwise its colons are mistaken for the port separator.
    let url_host = if row.host_address.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("[{}]", row.host_address)
    } else {
        row.host_address.to_string()
    };
    let url = format!(
        "wss://{}:{}/api/v1/ws/jobs",
        url_host, DEFAULT_AGENT_PORT,
    );

    let (ws_stream, _) = match connect_async_tls_with_config(
        url.as_str(),
        None,
        false,
        Some(Connector::Rustls(tls_config)),
    )
    .await
    {
        Ok(ws) => ws,
        Err(e) => {
            // Log the top-level error first: it is not part of its own source
            // chain, and some errors have no source at all.
            tracing::warn!(
                job_id = %row.job_id,
                host_id = %row.host_id,
                error = %e,
                "WS relay: WebSocket connection error"
            );
            // Then log the full error chain for TLS debugging.
            let mut source = e.source();
            let mut depth = 0;
            while let Some(err) = source {
                tracing::warn!(
                    job_id = %row.job_id,
                    host_id = %row.host_id,
                    depth,
                    error = %err,
                    "WS relay: TLS connection error detail"
                );
                source = err.source();
                depth += 1;
            }
            // Fall back to HTTP polling instead of returning an error.
            tracing::info!(
                job_id = %row.job_id,
                host_id = %row.host_id,
                host = %row.host_address,
                "WS relay: WebSocket connection failed, falling back to HTTP polling"
            );
            return relay_one_job_poll(pool, row, cert_pems, poll_interval_secs).await;
        },
    };

    // Only the read half is used; the write half is kept alive until return.
    let (_sink, mut stream) = ws_stream.split();
    while let Some(frame) = stream.next().await {
        let frame = match frame {
            Ok(f) => f,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    job_id = %row.job_id,
                    host_id = %row.host_id,
                    "WS relay: stream error"
                );
                break;
            },
        };

        let text = match frame {
            Message::Text(t) => t.to_string(),
            // Binary payloads that are not valid UTF-8 become "" and are
            // skipped by the emptiness check below.
            Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(),
            Message::Close(_) => {
                tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly");
                break;
            },
            _ => continue,
        };
        if text.is_empty() {
            continue;
        }

        let event: AgentWsEvent = match serde_json::from_str(&text) {
            Ok(e) => e,
            Err(e) => {
                tracing::warn!(
                    error = %e, raw = %text,
                    "WS relay: unparseable agent frame"
                );
                continue;
            },
        };

        process_event(pool, row, &event).await;

        if matches!(event.status.as_str(), "succeeded" | "failed" | "cancelled") {
            tracing::info!(
                job_id = %row.job_id,
                host_id = %row.host_id,
                status = %event.status,
                "WS relay: terminal state — stopping"
            );
            break;
        }
    }
    Ok(())
}
// ── Per-job HTTP polling fallback ─────────────────────────────────────────────
/// HTTP polling fallback used when the WebSocket connection cannot be opened.
///
/// Builds an [`AgentClient`] from the already-loaded PEM bytes in `cert_pems`,
/// then repeatedly sleeps for `poll_interval_secs` and asks the agent for the
/// status of `row.agent_job_id`. The agent status is normalised (`queued` and
/// any unknown value count as `running`, `completed` counts as `succeeded`)
/// and [`process_event`] is called only when the normalised status differs
/// from the one seen on the previous successful poll. Polling ends once the
/// normalised status is `succeeded`, `failed` or `cancelled`.
///
/// Failed status requests are logged at debug level and retried on the next
/// iteration.
///
/// # Errors
///
/// Returns an error only if the [`AgentClient`] cannot be constructed.
async fn relay_one_job_poll(
    pool: &PgPool,
    row: &RunningHostJob,
    cert_pems: &CertPems,
    poll_interval_secs: u64,
) -> anyhow::Result<()> {
    // The polling client authenticates with the same mTLS material.
    let client = AgentClient::new(
        &row.host_address,
        DEFAULT_AGENT_PORT,
        &cert_pems.client_cert,
        &cert_pems.client_key,
        &cert_pems.ca_cert,
    )
    .context("build AgentClient for polling fallback")?;

    let pause = Duration::from_secs(poll_interval_secs);
    let mut previous: Option<String> = None;

    loop {
        tokio::time::sleep(pause).await;

        let Ok(job_status) = client.job_status(&row.agent_job_id).await.map_err(|e| {
            tracing::debug!(
                job_id = %row.job_id,
                host_id = %row.host_id,
                error = %e,
                "WS relay poll: job_status request failed, will retry"
            )
        }) else {
            continue;
        };

        // Translate the agent's vocabulary into the one process_event accepts:
        // "queued" is reported as "running" and "completed" as "succeeded".
        let mapped = match job_status.status.as_str() {
            "queued" | "running" => "running",
            "succeeded" | "completed" => "succeeded",
            "failed" => "failed",
            "cancelled" => "cancelled",
            other => {
                tracing::warn!(
                    status = %other,
                    job_id = %row.job_id,
                    "WS relay poll: unknown agent status, treating as running"
                );
                "running"
            },
        };

        tracing::debug!(
            job_id = %row.job_id,
            host_id = %row.host_id,
            agent_status = %job_status.status,
            mapped_status = %mapped,
            progress = ?job_status.progress_percent,
            "WS relay poll: fetched job status"
        );

        // Nothing to report while the status is the same as last time.
        if previous.as_deref() == Some(mapped) {
            continue;
        }
        previous = Some(mapped.to_owned());

        let event = AgentWsEvent {
            job_id: job_status.job_id.clone(),
            status: mapped.to_owned(),
            output: job_status.output.clone(),
            error: job_status.error.clone(),
            progress_percent: job_status.progress_percent,
        };
        process_event(pool, row, &event).await;

        // A terminal status ends the polling loop.
        if matches!(mapped, "succeeded" | "failed" | "cancelled") {
            tracing::info!(
                job_id = %row.job_id,
                host_id = %row.host_id,
                status = %mapped,
                "WS relay poll: terminal state — stopping"
            );
            return Ok(());
        }
    }
}
// ── Event processing ──────────────────────────────────────────────────────────
async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) {
// Map agent status string to DB job_status enum value.
let db_status = match event.status.as_str() {
"running" => "running",
"succeeded" => "succeeded",
"failed" => "failed",
"cancelled" => "cancelled",
other => {
tracing::warn!(status = %other, "WS relay: unknown agent status");
return;
},
};
let output = event.output.as_deref().unwrap_or("");
let error_msg = event.error.as_deref();
// Determine timestamps based on terminal state.
let is_terminal = matches!(db_status, "succeeded" | "failed" | "cancelled");
// Update the DB row.
let update_result = if is_terminal {
sqlx::query(
r#"
UPDATE patch_job_hosts
SET status = $1::job_status,
output = CASE WHEN $2 != '' THEN $2 ELSE output END,
error_message = $3,
completed_at = NOW()
WHERE job_id = $4
AND host_id = $5
"#,
)
.bind(db_status)
.bind(output)
.bind(error_msg)
.bind(row.job_id)
.bind(row.host_id)
.execute(pool)
.await
} else {
sqlx::query(
r#"
UPDATE patch_job_hosts
SET status = $1::job_status,
output = CASE WHEN $2 != '' THEN $2 ELSE output END
WHERE job_id = $3
AND host_id = $4
"#,
)
.bind(db_status)
.bind(output)
.bind(row.job_id)
.bind(row.host_id)
.execute(pool)
.await
};
if let Err(e) = update_result {
tracing::error!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: DB update failed"
);
return;
}
// Also update the parent patch_jobs status when the host-level job reaches
// a terminal state: running → if all hosts terminal then update parent.
if is_terminal {
update_parent_job_status(pool, row.job_id).await;
}
// Fire pg_notify so browser WS handlers forward the host-level event.
let payload = NotifyPayload {
event_type: "host".to_string(),
job_id: row.job_id.to_string(),
host_id: row.host_id.to_string(),
status: db_status.to_string(),
output: event.output.clone(),
error_message: event.error.clone(),
agent_job_id: row.agent_job_id.clone(),
succeeded_count: None,
failed_count: None,
host_count: None,
};
let payload_json = match serde_json::to_string(&payload) {
Ok(s) => s,
Err(e) => {
tracing::error!(error = %e, "WS relay: failed to serialize notify payload");
return;
},
};
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
.bind(&payload_json)
.execute(pool)
.await
{
tracing::error!(
error = %e,
job_id = %row.job_id,
host_id = %row.host_id,
"WS relay: pg_notify failed"
);
} else {
tracing::debug!(
job_id = %row.job_id,
host_id = %row.host_id,
status = %db_status,
"WS relay: pg_notify sent"
);
}
}
// ── Parent job status rollup ──────────────────────────────────────────────────
/// After a host-level job reaches a terminal state, check whether ALL hosts for
/// that job are now terminal and update the parent `patch_jobs` row accordingly.
///
/// While at least one host is still non-terminal the parent is left untouched.
/// Otherwise the parent status is derived from the per-host results (see
/// [`rollup_final_status`]): `failed` if any host failed, `succeeded` if every
/// host succeeded, and `cancelled` when no host failed but at least one was
/// cancelled.
///
/// If the parent job transitions to a terminal status, also fires a `job_update`
/// pg_notify with `event_type: "job"` carrying the succeeded/failed/total host
/// counts.
///
/// All database and serialisation errors are logged and end the function
/// early; nothing is returned to the caller.
async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) {
    // Count hosts that are still in a non-terminal state.
    let pending: i64 = match sqlx::query_scalar(
        r#"
SELECT COUNT(*)
FROM patch_job_hosts
WHERE job_id = $1
AND status NOT IN (
'succeeded'::job_status,
'failed'::job_status,
'cancelled'::job_status
)
"#,
    )
    .bind(job_id)
    .fetch_one(pool)
    .await
    {
        Ok(n) => n,
        Err(e) => {
            tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed");
            return;
        },
    };
    if pending > 0 {
        return; // still hosts running — parent stays running
    }

    // All hosts terminal — determine final parent status and counts.
    #[derive(sqlx::FromRow)]
    struct RollupCounts {
        total: i64,
        succeeded: i64,
        failed: i64,
    }
    let counts: RollupCounts = match sqlx::query_as(
        r#"
SELECT
COUNT(*) AS total,
COUNT(*) FILTER (WHERE status = 'succeeded') AS succeeded,
COUNT(*) FILTER (WHERE status = 'failed') AS failed
FROM patch_job_hosts
WHERE job_id = $1
"#,
    )
    .bind(job_id)
    .fetch_one(pool)
    .await
    {
        Ok(c) => c,
        Err(e) => {
            tracing::error!(error = %e, %job_id, "update_parent_job_status: rollup query failed");
            return;
        },
    };

    let final_status = rollup_final_status(counts.total, counts.succeeded, counts.failed);

    if let Err(e) = sqlx::query(
        "UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2",
    )
    .bind(final_status)
    .bind(job_id)
    .execute(pool)
    .await
    {
        tracing::error!(
            error = %e,
            %job_id,
            status = %final_status,
            "update_parent_job_status: UPDATE failed"
        );
        return;
    }
    tracing::info!(
        %job_id,
        status = %final_status,
        "Parent job status updated"
    );

    // Fire job-level pg_notify so the frontend can update the job row.
    let payload = NotifyPayload {
        event_type: "job".to_string(),
        job_id: job_id.to_string(),
        host_id: String::new(), // no specific host for job-level events
        status: final_status.to_string(),
        output: None,
        error_message: None,
        agent_job_id: String::new(),
        succeeded_count: Some(counts.succeeded),
        failed_count: Some(counts.failed),
        host_count: Some(counts.total),
    };
    let payload_json = match serde_json::to_string(&payload) {
        Ok(s) => s,
        Err(e) => {
            tracing::error!(
                error = %e,
                %job_id,
                "update_parent_job_status: failed to serialize job-level notify payload"
            );
            return;
        },
    };
    if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
        .bind(&payload_json)
        .execute(pool)
        .await
    {
        tracing::error!(
            error = %e,
            %job_id,
            "update_parent_job_status: job-level pg_notify failed"
        );
    } else {
        tracing::info!(
            %job_id,
            status = %final_status,
            "Job-level pg_notify sent"
        );
    }
}

/// Derives the parent job status from per-host counts once every host is in a
/// terminal state.
///
/// * any failed host → `"failed"`
/// * no failure, but fewer successes than hosts (the remainder was cancelled)
///   → `"cancelled"`
/// * otherwise → `"succeeded"`
fn rollup_final_status(total: i64, succeeded: i64, failed: i64) -> &'static str {
    if failed > 0 {
        "failed"
    } else if succeeded < total {
        // Previously this case was reported as "succeeded", which hid
        // cancelled hosts behind a successful-looking parent job.
        "cancelled"
    } else {
        "succeeded"
    }
}