feat: add bump-version.sh script for version management
Automates version bumps across all version source files:
- Cargo.toml (PRIMARY - workspace.package.version)
- debian/changelog (prepend new entry)
- debian/control (update Version field)
- scripts/build-package.sh (update VERSION variable)
- frontend/package.json (update version field)
- Stale references check after bump

Usage: ./scripts/bump-version.sh <new_version> <old_version>
This commit is contained in:
31
crates/pm-worker/Cargo.toml
Normal file
31
crates/pm-worker/Cargo.toml
Normal file
@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "pm-worker"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "pm-worker"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
pm-agent-client = { path = "../pm-agent-client" }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { version = "0.26" }
|
||||
rustls-pemfile = { version = "2" }
|
||||
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
|
||||
reqwest = { workspace = true }
|
||||
45
crates/pm-worker/src/agent_loader.rs
Executable file
45
crates/pm-worker/src/agent_loader.rs
Executable file
@ -0,0 +1,45 @@
|
||||
//! Helper for loading mTLS certificate/key material from disk.
|
||||
//!
|
||||
//! Reads PEM files referenced in [`SecurityConfig`] and returns the raw bytes
|
||||
//! needed by [`pm_agent_client::AgentClient`].
|
||||
|
||||
use pm_core::config::SecurityConfig;
|
||||
|
||||
/// Raw PEM bytes for mTLS client authentication and CA verification.
///
/// Each field holds the unparsed contents of one file exactly as read from
/// disk by [`load_agent_certs`]; nothing is validated or decoded here.
pub struct AgentCerts {
    /// Contents of the file at `SecurityConfig::agent_client_cert_path`.
    pub client_cert: Vec<u8>,
    /// Contents of the file at `SecurityConfig::agent_client_key_path`.
    pub client_key: Vec<u8>,
    /// Contents of the file at `SecurityConfig::ca_cert_path`.
    pub ca_cert: Vec<u8>,
}
|
||||
|
||||
/// Load agent mTLS certificates from the paths specified in [`SecurityConfig`].
|
||||
///
|
||||
/// Returns an error if any file cannot be read. The caller should handle
|
||||
/// the error gracefully (log and skip the poll cycle) rather than crashing.
|
||||
pub fn load_agent_certs(security: &SecurityConfig) -> anyhow::Result<AgentCerts> {
|
||||
let client_cert = std::fs::read(&security.agent_client_cert_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read agent client cert '{}': {}",
|
||||
security.agent_client_cert_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let client_key = std::fs::read(&security.agent_client_key_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read agent client key '{}': {}",
|
||||
security.agent_client_key_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let ca_cert = std::fs::read(&security.ca_cert_path).map_err(|e| {
|
||||
anyhow::anyhow!("Failed to read CA cert '{}': {}", security.ca_cert_path, e)
|
||||
})?;
|
||||
|
||||
Ok(AgentCerts {
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
})
|
||||
}
|
||||
86
crates/pm-worker/src/audit_verifier.rs
Executable file
86
crates/pm-worker/src/audit_verifier.rs
Executable file
@ -0,0 +1,86 @@
|
||||
//! Periodic audit log integrity verification.
|
||||
//!
|
||||
//! Runs every 24 hours, walks the audit_log rows ordered by id,
|
||||
//! verifies each row_hash matches the recomputed hash, and logs the
|
||||
//! result as an `AuditIntegrityVerified` event. If tampering is
|
||||
//! detected, logs an error and creates an alert.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use sqlx::PgPool;
|
||||
|
||||
use pm_core::audit::{log_event, verify_integrity, AuditAction};
|
||||
use pm_core::config::AppConfig;
|
||||
|
||||
/// Run the audit integrity verifier every 24 hours.
|
||||
pub async fn run_audit_verifier(pool: PgPool, _config: Arc<AppConfig>) {
|
||||
tracing::info!("Audit integrity verifier started");
|
||||
|
||||
// Run immediately on startup
|
||||
verify_once(&pool).await;
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(24 * 60 * 60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
tracing::info!("Running scheduled audit integrity verification");
|
||||
verify_once(&pool).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a single integrity verification pass.
|
||||
async fn verify_once(pool: &PgPool) {
|
||||
let result = verify_integrity(pool).await;
|
||||
|
||||
if result.intact {
|
||||
tracing::info!(
|
||||
rows_checked = result.rows_checked,
|
||||
"Audit integrity verification passed"
|
||||
);
|
||||
} else {
|
||||
tracing::error!(
|
||||
rows_checked = result.rows_checked,
|
||||
error_count = result.errors.len(),
|
||||
"Audit integrity verification FAILED — tampering detected!"
|
||||
);
|
||||
|
||||
for err in &result.errors {
|
||||
tracing::error!(
|
||||
row_id = err.row_id,
|
||||
expected_hash = %err.expected_hash,
|
||||
actual_hash = %err.actual_hash,
|
||||
"Audit chain integrity error"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Log the verification event
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::AuditIntegrityVerified,
|
||||
None,
|
||||
None,
|
||||
Some("audit_log"),
|
||||
None,
|
||||
serde_json::json!({
|
||||
"intact": result.intact,
|
||||
"rows_checked": result.rows_checked,
|
||||
"error_count": result.errors.len(),
|
||||
"errors": result.errors.iter().take(10).map(|e| serde_json::json!({
|
||||
"row_id": e.row_id,
|
||||
"expected_hash": e.expected_hash,
|
||||
"actual_hash": e.actual_hash,
|
||||
})).collect::<Vec<_>>(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Update last verified timestamp
|
||||
let _ = sqlx::query(
|
||||
"UPDATE system_config SET value = NOW()::text, updated_at = NOW() WHERE key = 'audit_integrity_last_verified'",
|
||||
)
|
||||
.execute(pool)
|
||||
.await;
|
||||
}
|
||||
331
crates/pm-worker/src/email.rs
Executable file
331
crates/pm-worker/src/email.rs
Executable file
@ -0,0 +1,331 @@
|
||||
//! Email notification module.
|
||||
//!
|
||||
//! Loads SMTP configuration from `system_config` and sends notification emails
|
||||
//! for patch job events (completion, failure) and maintenance window reminders.
|
||||
//! All emails are optional and disabled by default via `notification_email_enabled`.
|
||||
|
||||
use lettre::{
|
||||
message::{header::ContentType, Mailbox},
|
||||
transport::smtp::authentication::Credentials,
|
||||
AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use pm_core::audit::{log_event, AuditAction};
|
||||
|
||||
/// SMTP configuration loaded from `system_config`.
///
/// Populated by `load_smtp_settings`; every field maps to one `smtp_*` key and
/// falls back to an empty/default value when the key is absent.
struct SmtpSettings {
    /// `true` only when `smtp_enabled` is exactly the string `"true"`.
    enabled: bool,
    /// Value of `smtp_host`; empty when unset.
    host: String,
    /// Value of `smtp_port`; 587 when missing or not a valid `u16`.
    port: u16,
    /// Value of `smtp_username`; when empty, no credentials are attached to
    /// the transport.
    username: String,
    /// Value of `smtp_password`, stored as read from the table.
    password: String,
    /// Value of `smtp_from`; used as the sender when
    /// `notification_email_from` is empty.
    from: String,
    /// Value of `smtp_tls_mode`: `"tls"` and `"starttls"` select the
    /// corresponding lettre relay; any other value yields a plaintext
    /// connection.
    tls_mode: String,
}
|
||||
|
||||
/// Notification preferences loaded from `system_config`.
///
/// Populated by `load_notification_settings` from the `notification_email_*`
/// keys.
struct NotificationSettings {
    /// `true` only when `notification_email_enabled` is exactly `"true"`.
    email_enabled: bool,
    /// Value of `notification_email_from`; when non-empty it overrides the
    /// SMTP-level sender address.
    email_from: String,
    /// Addresses parsed from `notification_email_recipients`, which is
    /// expected to hold a JSON array of strings; empty when the key is
    /// missing or the value does not parse.
    recipients: Vec<String>,
}
|
||||
|
||||
/// Load SMTP settings from the `system_config` table.
|
||||
async fn load_smtp_settings(pool: &PgPool) -> SmtpSettings {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(
|
||||
"SELECT key, value FROM system_config WHERE key IN (
|
||||
'smtp_enabled', 'smtp_host', 'smtp_port', 'smtp_username',
|
||||
'smtp_password', 'smtp_from', 'smtp_tls_mode'
|
||||
)",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let get = |key: &str| -> String {
|
||||
rows.iter()
|
||||
.find(|(k, _)| k == key)
|
||||
.map(|(_, v)| v.clone())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
SmtpSettings {
|
||||
enabled: get("smtp_enabled") == "true",
|
||||
host: get("smtp_host"),
|
||||
port: get("smtp_port").parse().unwrap_or(587),
|
||||
username: get("smtp_username"),
|
||||
password: get("smtp_password"),
|
||||
from: get("smtp_from"),
|
||||
tls_mode: get("smtp_tls_mode"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load notification preferences from `system_config`.
|
||||
async fn load_notification_settings(pool: &PgPool) -> NotificationSettings {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(
|
||||
"SELECT key, value FROM system_config WHERE key IN (
|
||||
'notification_email_enabled', 'notification_email_from', 'notification_email_recipients'
|
||||
)",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let get = |key: &str| -> String {
|
||||
rows.iter()
|
||||
.find(|(k, _)| k == key)
|
||||
.map(|(_, v)| v.clone())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
let recipients: Vec<String> =
|
||||
serde_json::from_str(&get("notification_email_recipients")).unwrap_or_default();
|
||||
|
||||
NotificationSettings {
|
||||
email_enabled: get("notification_email_enabled") == "true",
|
||||
email_from: get("notification_email_from"),
|
||||
recipients,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an async SMTP transport from settings.
|
||||
fn build_transport(settings: &SmtpSettings) -> Result<AsyncSmtpTransport<Tokio1Executor>, String> {
|
||||
match settings.tls_mode.as_str() {
|
||||
"tls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::relay(&settings.host)
|
||||
.map_err(|e| format!("TLS relay error: {}", e))?;
|
||||
builder = builder.port(settings.port);
|
||||
if !settings.username.is_empty() {
|
||||
builder = builder.credentials(Credentials::new(
|
||||
settings.username.clone(),
|
||||
settings.password.clone(),
|
||||
));
|
||||
}
|
||||
Ok(builder.build())
|
||||
},
|
||||
"starttls" => {
|
||||
let mut builder = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(&settings.host)
|
||||
.map_err(|e| format!("STARTTLS relay error: {}", e))?;
|
||||
builder = builder.port(settings.port);
|
||||
if !settings.username.is_empty() {
|
||||
builder = builder.credentials(Credentials::new(
|
||||
settings.username.clone(),
|
||||
settings.password.clone(),
|
||||
));
|
||||
}
|
||||
Ok(builder.build())
|
||||
},
|
||||
_ => {
|
||||
// "none" — plaintext / no TLS
|
||||
let mut builder =
|
||||
AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(&settings.host)
|
||||
.port(settings.port);
|
||||
if !settings.username.is_empty() {
|
||||
builder = builder.credentials(Credentials::new(
|
||||
settings.username.clone(),
|
||||
settings.password.clone(),
|
||||
));
|
||||
}
|
||||
Ok(builder.build())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Send an email notification. Returns true if the email was sent successfully.
|
||||
async fn send_email(pool: &PgPool, subject: &str, body: &str) -> bool {
|
||||
let smtp = match load_smtp_settings(pool).await {
|
||||
s if !s.enabled => {
|
||||
tracing::debug!("SMTP not enabled, skipping email notification");
|
||||
return false;
|
||||
},
|
||||
s => s,
|
||||
};
|
||||
|
||||
let notif = load_notification_settings(pool).await;
|
||||
if !notif.email_enabled {
|
||||
tracing::debug!("Email notifications disabled, skipping");
|
||||
return false;
|
||||
}
|
||||
|
||||
if notif.recipients.is_empty() {
|
||||
tracing::debug!("No email recipients configured, skipping notification");
|
||||
return false;
|
||||
}
|
||||
|
||||
let from_addr = if notif.email_from.is_empty() {
|
||||
smtp.from.clone()
|
||||
} else {
|
||||
notif.email_from
|
||||
};
|
||||
|
||||
let from_mailbox: Mailbox = match from_addr.parse() {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Invalid from address for email notification");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let mut builder = Message::builder()
|
||||
.from(from_mailbox.clone())
|
||||
.subject(subject)
|
||||
.header(ContentType::TEXT_PLAIN);
|
||||
|
||||
// Add all recipients
|
||||
for recipient in ¬if.recipients {
|
||||
let mailbox: Mailbox = match recipient.parse() {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, recipient = %recipient, "Invalid recipient address");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
builder = builder.to(mailbox);
|
||||
}
|
||||
|
||||
let email = match builder.body(body.to_string()) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to build email message");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let transport = match build_transport(&smtp) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to build SMTP transport");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
match transport.send(email).await {
|
||||
Ok(_) => {
|
||||
tracing::info!(subject, "Email notification sent successfully");
|
||||
true
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, subject, "Failed to send email notification");
|
||||
false
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a patch failure notification email for a specific host.
|
||||
pub async fn send_patch_failure_email(
|
||||
pool: &PgPool,
|
||||
host_fqdn: &str,
|
||||
job_id: &str,
|
||||
error_message: &str,
|
||||
) {
|
||||
let subject = format!("[Patch Manager] Patch Failed on {}", host_fqdn);
|
||||
let body = format!(
|
||||
"Patch operation failed on host: {host_fqdn}\n\
|
||||
Job ID: {job_id}\n\
|
||||
Error: {error_message}\n\
|
||||
\n\
|
||||
Please review the job details in the Patch Manager dashboard."
|
||||
);
|
||||
|
||||
let sent = send_email(pool, &subject, &body).await;
|
||||
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::EmailNotificationSent,
|
||||
None,
|
||||
None,
|
||||
Some("patch_job"),
|
||||
Some(job_id),
|
||||
serde_json::json!({
|
||||
"type": "patch_failure",
|
||||
"host_fqdn": host_fqdn,
|
||||
"sent": sent,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Send a job completion notification email.
|
||||
pub async fn send_job_completion_email(
|
||||
pool: &PgPool,
|
||||
job_id: &str,
|
||||
host_count: i64,
|
||||
succeeded_count: i64,
|
||||
failed_count: i64,
|
||||
) {
|
||||
let subject = format!("[Patch Manager] Job {} Completed", job_id);
|
||||
let body = format!(
|
||||
"Patch job completed: {job_id}\n\
|
||||
Total hosts: {host_count}\n\
|
||||
Succeeded: {succeeded_count}\n\
|
||||
Failed: {failed_count}\n\
|
||||
\n\
|
||||
Please review the job details in the Patch Manager dashboard."
|
||||
);
|
||||
|
||||
let sent = send_email(pool, &subject, &body).await;
|
||||
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::EmailNotificationSent,
|
||||
None,
|
||||
None,
|
||||
Some("patch_job"),
|
||||
Some(job_id),
|
||||
serde_json::json!({
|
||||
"type": "job_completion",
|
||||
"host_count": host_count,
|
||||
"succeeded_count": succeeded_count,
|
||||
"failed_count": failed_count,
|
||||
"sent": sent,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Send a maintenance window reminder email.
|
||||
#[allow(dead_code)]
|
||||
pub async fn send_maintenance_window_reminder_email(
|
||||
pool: &PgPool,
|
||||
host_fqdn: &str,
|
||||
window_label: &str,
|
||||
start_at: &str,
|
||||
) {
|
||||
let subject = format!(
|
||||
"[Patch Manager] Upcoming Maintenance Window: {}",
|
||||
window_label
|
||||
);
|
||||
let body = format!(
|
||||
"Maintenance window reminder:\n\
|
||||
Host: {host_fqdn}\n\
|
||||
Window: {window_label}\n\
|
||||
Starts at: {start_at}\n\
|
||||
\n\
|
||||
Patch operations will begin at the scheduled time."
|
||||
);
|
||||
|
||||
let sent = send_email(pool, &subject, &body).await;
|
||||
|
||||
log_event(
|
||||
pool,
|
||||
AuditAction::MaintenanceWindowReminder,
|
||||
None,
|
||||
None,
|
||||
Some("maintenance_window"),
|
||||
None,
|
||||
serde_json::json!({
|
||||
"type": "maintenance_reminder",
|
||||
"host_fqdn": host_fqdn,
|
||||
"window_label": window_label,
|
||||
"sent": sent,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
480
crates/pm-worker/src/health_check_poller.rs
Executable file
480
crates/pm-worker/src/health_check_poller.rs
Executable file
@ -0,0 +1,480 @@
|
||||
//! Periodic health check poller for configured service and HTTP checks.
|
||||
//!
|
||||
//! Polls every `health_check_poll_interval_secs`, querying each enabled health
|
||||
//! check definition and storing results in `host_health_check_results`.
|
||||
//! Results older than 4 days are pruned on each cycle.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use pm_core::{config::AppConfig, crypto};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// DB row types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Row fetched for each enabled health check, joined with host connection info.
///
/// Field names must match the column names/aliases of the poller's SELECT,
/// because the struct is filled through `FromRow`.
#[derive(FromRow)]
// Some selected columns (e.g. `name`, `host_id`, `target_host_id`) are not
// read by the check functions in this module, hence the lint allowance.
#[allow(dead_code)]
struct HealthCheckRow {
    /// Health check id (`hc.id`); stored as `check_id` on each result row.
    id: Uuid,
    /// Host the check is configured on (`hc.host_id`).
    host_id: Uuid,
    /// Name of the check (`hc.name`).
    name: String,
    /// Check kind: `"service"` or `"http"`; any other value is recorded as
    /// unhealthy.
    check_type: String,
    /// Service to query; required for `"service"` checks.
    service_name: Option<String>,
    /// URL to GET; required for `"http"` checks.
    url: Option<String>,
    /// Substring that must appear in the HTTP response body, if set.
    expected_body: Option<String>,
    /// Whether TLS certificate errors are ignored for HTTP checks; `None` is
    /// treated as `true`.
    ignore_cert_errors: Option<bool>,
    /// Basic-auth user name; when set, an encrypted password and nonce must
    /// also be present.
    basic_auth_user: Option<String>,
    /// Encrypted basic-auth password, passed to `crypto::decrypt`.
    basic_auth_pass_encrypted: Option<Vec<u8>>,
    /// Nonce passed to `crypto::decrypt` together with the ciphertext.
    basic_auth_pass_nonce: Option<Vec<u8>>,
    /// Optional alternative host (`hc.target_host_id`); when it matches a
    /// `hosts` row, that row's address and port take precedence.
    target_host_id: Option<Uuid>,
    /// Textual IP of the target host if present, otherwise of the owning host.
    ip_address: String,
    /// Agent port of the target host if present, otherwise of the owning host.
    agent_port: i32,
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public entry point
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run the health check poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all enabled health checks are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_health_check_results` and stale rows are pruned.
|
||||
pub async fn run_health_check_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.health_check_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(interval_secs, "Health check poller started");
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Load certs on each cycle so cert rotation is picked up automatically.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Health check poller: failed to load agent certs — skipping cycle"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
// Load the crypto key for decrypting HTTP check passwords.
|
||||
let crypto_key = match crypto::load_or_create_key(Path::new(crypto::KEY_PATH)) {
|
||||
Ok(k) => Arc::new(k),
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Health check poller: failed to load crypto key — skipping cycle"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Fetch all enabled health checks with host connection info.
|
||||
let checks: Vec<HealthCheckRow> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
hc.id,
|
||||
hc.host_id,
|
||||
hc.name,
|
||||
hc.check_type,
|
||||
hc.service_name,
|
||||
hc.url,
|
||||
hc.expected_body,
|
||||
hc.ignore_cert_errors,
|
||||
hc.basic_auth_user,
|
||||
hc.basic_auth_pass_encrypted,
|
||||
hc.basic_auth_pass_nonce,
|
||||
hc.target_host_id,
|
||||
host(COALESCE(th.ip_address, h.ip_address))::text AS ip_address,
|
||||
COALESCE(th.agent_port, h.agent_port) AS agent_port
|
||||
FROM host_health_checks hc
|
||||
JOIN hosts h ON h.id = hc.host_id
|
||||
LEFT JOIN hosts th ON th.id = hc.target_host_id
|
||||
WHERE hc.enabled = TRUE
|
||||
ORDER BY hc.id
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health check poller: failed to fetch health checks");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if checks.is_empty() {
|
||||
tracing::debug!("Health check poller: no enabled health checks, skipping cycle");
|
||||
prune_old_results(&pool).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = checks.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for check in checks {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
let ckey = crypto_key.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
run_check(pool, check, &cert, &key, &ca, &ckey).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Collect results and tally counts.
|
||||
let mut healthy_count = 0usize;
|
||||
let mut unhealthy_count = 0usize;
|
||||
let mut error_count = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(true) => healthy_count += 1,
|
||||
Ok(false) => unhealthy_count += 1,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health check poller task panicked");
|
||||
error_count += 1;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
total,
|
||||
healthy_count,
|
||||
unhealthy_count,
|
||||
error_count,
|
||||
"Health check poll cycle complete"
|
||||
);
|
||||
|
||||
// Prune results older than 4 days.
|
||||
prune_old_results(&pool).await;
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Check dispatch
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run a single health check and persist the result. Returns `true` if healthy.
|
||||
async fn run_check(
|
||||
pool: PgPool,
|
||||
check: HealthCheckRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
crypto_key: &[u8; 32],
|
||||
) -> bool {
|
||||
let start = Instant::now();
|
||||
|
||||
let (healthy, detail) = match check.check_type.as_str() {
|
||||
"service" => run_service_check(&check, client_cert, client_key, ca_cert).await,
|
||||
"http" => run_http_check(&check, crypto_key).await,
|
||||
other => {
|
||||
tracing::warn!(
|
||||
check_id = %check.id,
|
||||
check_type = other,
|
||||
"Unknown health check type — treating as unhealthy"
|
||||
);
|
||||
(false, format!("Unknown check type: {other}"))
|
||||
},
|
||||
};
|
||||
|
||||
let latency_ms = start.elapsed().as_millis() as i32;
|
||||
|
||||
// Persist the result.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_check_results (check_id, healthy, detail, latency_ms)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
)
|
||||
.bind(check.id)
|
||||
.bind(healthy)
|
||||
.bind(&detail)
|
||||
.bind(latency_ms)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
check_id = %check.id,
|
||||
error = %e,
|
||||
"Health check poller: failed to insert result"
|
||||
);
|
||||
}
|
||||
|
||||
healthy
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Service check (via mTLS AgentClient)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute a service check by calling the agent's `/api/v1/system/services/{name}` endpoint.
|
||||
async fn run_service_check(
|
||||
check: &HealthCheckRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> (bool, String) {
|
||||
let service_name = match &check.service_name {
|
||||
Some(name) => name.clone(),
|
||||
None => {
|
||||
return (false, "Service check missing service_name".to_string());
|
||||
},
|
||||
};
|
||||
|
||||
let client = match AgentClient::new(
|
||||
&check.ip_address,
|
||||
check.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return (false, format!("Failed to build AgentClient: {e}"));
|
||||
},
|
||||
};
|
||||
|
||||
match client.service_status(&service_name).await {
|
||||
Ok(data) => {
|
||||
let detail = if data.healthy {
|
||||
format!(
|
||||
"Service '{}' is {}/{} (enabled: {})",
|
||||
data.name, data.active_state, data.sub_state, data.enabled_state
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"Service '{}' status: {}/{} (unhealthy, enabled: {})",
|
||||
data.name, data.active_state, data.sub_state, data.enabled_state
|
||||
)
|
||||
};
|
||||
(data.healthy, detail)
|
||||
},
|
||||
Err(AgentClientError::Timeout) => (
|
||||
false,
|
||||
format!("Agent timed out querying service '{service_name}'"),
|
||||
),
|
||||
Err(AgentClientError::Connect(_)) => (
|
||||
false,
|
||||
format!("Agent connection refused for service '{service_name}'"),
|
||||
),
|
||||
Err(AgentClientError::ApiError { code, message }) => {
|
||||
// 404, 400, 500 etc. from the agent means the service is unhealthy.
|
||||
(false, format!("Agent error [{code}]: {message}"))
|
||||
},
|
||||
Err(e) => (
|
||||
false,
|
||||
format!("Agent error querying service '{service_name}': {e}"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// HTTP check (via reqwest, no mTLS)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute an HTTP check by making a GET request to the configured URL.
|
||||
/// Supports optional basic auth (decrypted from DB) and substring body matching.
|
||||
async fn run_http_check(check: &HealthCheckRow, crypto_key: &[u8; 32]) -> (bool, String) {
|
||||
let url = match &check.url {
|
||||
Some(u) => u.clone(),
|
||||
None => {
|
||||
return (false, "HTTP check missing URL".to_string());
|
||||
},
|
||||
};
|
||||
|
||||
// Build a reqwest client for this check.
|
||||
// Use danger_accept_invalid_certs if ignore_cert_errors is set (default true).
|
||||
let ignore_cert_errors = check.ignore_cert_errors.unwrap_or(true);
|
||||
|
||||
let client_builder = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.redirect(reqwest::redirect::Policy::limited(5));
|
||||
|
||||
let client = if ignore_cert_errors {
|
||||
client_builder
|
||||
.danger_accept_invalid_certs(true)
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
} else {
|
||||
client_builder
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
};
|
||||
|
||||
// Build the request.
|
||||
let mut request = client.get(&url);
|
||||
|
||||
// Add basic auth if configured.
|
||||
if let Some(user) = &check.basic_auth_user {
|
||||
// Decrypt the password if present.
|
||||
let password = match (
|
||||
&check.basic_auth_pass_encrypted,
|
||||
&check.basic_auth_pass_nonce,
|
||||
) {
|
||||
(Some(enc), Some(nonce)) => match crypto::decrypt(enc, nonce, crypto_key) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return (false, format!("Failed to decrypt basic auth password: {e}"));
|
||||
},
|
||||
},
|
||||
_ => {
|
||||
// No encrypted password stored — treat as missing credentials.
|
||||
return (
|
||||
false,
|
||||
"HTTP check has basic_auth_user but no encrypted password".to_string(),
|
||||
);
|
||||
},
|
||||
};
|
||||
request = request.basic_auth(user.as_str(), Some(password.as_str()));
|
||||
}
|
||||
|
||||
// Execute the request.
|
||||
let response = match request.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
if e.is_timeout() {
|
||||
return (false, format!("HTTP check timed out: {url}"));
|
||||
} else if e.is_connect() {
|
||||
return (false, format!("HTTP check connection failed: {url}"));
|
||||
} else {
|
||||
return (false, format!("HTTP check request error: {e}"));
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
|
||||
// Check HTTP status code.
|
||||
if !status.is_success() {
|
||||
return (
|
||||
false,
|
||||
format!("HTTP check returned status {} for {url}", status.as_u16()),
|
||||
);
|
||||
}
|
||||
|
||||
// Read the response body for substring matching.
|
||||
let body = match response.text().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
return (
|
||||
false,
|
||||
format!("HTTP check failed to read response body: {e}"),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
// Check expected_body substring match.
|
||||
if let Some(expected) = &check.expected_body {
|
||||
if !body.contains(expected) {
|
||||
return (
|
||||
false,
|
||||
format!("HTTP check body mismatch for {url}: expected substring not found"),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
true,
|
||||
format!("HTTP check OK for {url} (status {})", status.as_u16()),
|
||||
)
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Prune old results
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Delete health check results older than 4 days.
|
||||
async fn prune_old_results(pool: &PgPool) {
|
||||
match sqlx::query(
|
||||
"DELETE FROM host_health_check_results WHERE checked_at < NOW() - INTERVAL '4 days'",
|
||||
)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!(
|
||||
rows_deleted = result.rows_affected(),
|
||||
"Health check poller: pruned old results"
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health check poller: failed to prune old results");
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Health check gate for job executor
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Check whether all enabled health checks for a host are healthy.
|
||||
///
|
||||
/// Returns `Ok(true)` if all checks pass (or no checks are configured),
|
||||
/// `Ok(false)` if any check is unhealthy or has no result yet.
|
||||
pub async fn check_host_health_checks(pool: &PgPool, host_id: Uuid) -> anyhow::Result<bool> {
|
||||
// Check if there are any enabled health checks for this host.
|
||||
let check_count: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM host_health_checks WHERE host_id = $1 AND enabled = TRUE",
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if check_count.0 == 0 {
|
||||
// No health checks configured for this host — treat as healthy.
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// Find any enabled check that has no healthy result or an unhealthy latest result.
|
||||
let unhealthy_count: (i64,) = sqlx::query_as(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM host_health_checks hc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT healthy
|
||||
FROM host_health_check_results r
|
||||
WHERE r.check_id = hc.id
|
||||
ORDER BY r.checked_at DESC
|
||||
LIMIT 1
|
||||
) latest ON true
|
||||
WHERE hc.host_id = $1
|
||||
AND hc.enabled = TRUE
|
||||
AND (latest.healthy IS NULL OR latest.healthy = FALSE)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(unhealthy_count.0 == 0)
|
||||
}
|
||||
249
crates/pm-worker/src/health_poller.rs
Normal file
249
crates/pm-worker/src/health_poller.rs
Normal file
@ -0,0 +1,249 @@
|
||||
//! Periodic health poller for all registered hosts.
|
||||
//!
|
||||
//! Polls every host via the agent `/health` endpoint on each tick of
|
||||
//! `health_poll_interval_secs`, with bounded concurrency controlled by a
|
||||
//! [`tokio::sync::Semaphore`]. Also calls `/system/info` to refresh
|
||||
//! `os_family`, `os_name`, `arch`, and `agent_version` in the hosts table.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
use pm_core::{config::AppConfig, models::HostHealthStatus};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host projection fetched for each poll cycle.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the health poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all registered hosts are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_health_data` and the `hosts` table is updated.
|
||||
pub async fn run_health_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.health_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(interval_secs, "Health poller started");
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Load certs on each cycle so cert rotation is picked up automatically.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health poller: failed to load agent certs — skipping cycle");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
// Fetch all hosts.
|
||||
let hosts: Vec<HostRow> = match sqlx::query_as(
|
||||
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health poller: failed to fetch hosts");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!("Health poller: no hosts registered, skipping cycle");
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = hosts.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for host in hosts {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
poll_host_health(pool, host, &cert, &key, &ca).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Collect results and tally counts.
|
||||
let mut healthy = 0usize;
|
||||
let mut degraded = 0usize;
|
||||
let mut unreachable = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(HostHealthStatus::Healthy) => healthy += 1,
|
||||
Ok(HostHealthStatus::Degraded) => degraded += 1,
|
||||
Ok(HostHealthStatus::Unreachable) => unreachable += 1,
|
||||
Ok(_) => {},
|
||||
Err(e) => tracing::error!(error = %e, "Health poller task panicked"),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
total,
|
||||
healthy,
|
||||
degraded,
|
||||
unreachable,
|
||||
"Health poll cycle complete"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll a single host, persist the result, and return the determined status.
|
||||
///
|
||||
/// Also updates `agent_version` from the health response and
|
||||
/// `os_family`/`os_name`/`arch` from the `/system/info` endpoint when available.
|
||||
async fn poll_host_health(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> HostHealthStatus {
|
||||
// Determine status, payload, agent version, and optional system info.
|
||||
let (status, payload, agent_version, sys_info) = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Health poller: failed to build AgentClient"
|
||||
);
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
},
|
||||
Ok(client) => {
|
||||
let (status, payload, version) = match client.health().await {
|
||||
Ok(data) => {
|
||||
let payload = serde_json::to_value(&data).unwrap_or_default();
|
||||
(HostHealthStatus::Healthy, payload, Some(data.version))
|
||||
},
|
||||
Err(AgentClientError::Timeout) => {
|
||||
tracing::warn!(host_id = %host.id, "Health poller: agent timed out");
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
)
|
||||
},
|
||||
Err(AgentClientError::Connect(_)) => {
|
||||
tracing::warn!(host_id = %host.id, "Health poller: agent connection refused");
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
)
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Health poller: agent error");
|
||||
(
|
||||
HostHealthStatus::Degraded,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
None,
|
||||
)
|
||||
},
|
||||
};
|
||||
|
||||
// Try to fetch system info for OS/arch details (best-effort).
|
||||
let sys_info = if status != HostHealthStatus::Unreachable {
|
||||
match client.system_info().await {
|
||||
Ok(info) => Some(info),
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Health poller: failed to get system info (non-fatal)"
|
||||
);
|
||||
None
|
||||
},
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(status, payload, version, sys_info)
|
||||
},
|
||||
};
|
||||
|
||||
// Insert into host_health_data.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_data (host_id, status, payload)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&status)
|
||||
.bind(&payload)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to insert health data");
|
||||
}
|
||||
|
||||
// Build OS name from system info components (e.g. "Ubuntu 24.04").
|
||||
let os_name_from_sysinfo = sys_info
|
||||
.as_ref()
|
||||
.map(|i| format!("{} {}", i.os, i.os_version));
|
||||
|
||||
// Update hosts table with health status, agent version, and OS details.
|
||||
// COALESCE preserves existing values when new data is unavailable.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE hosts
|
||||
SET health_status = $2, last_health_at = NOW(),
|
||||
agent_version = COALESCE($3, agent_version),
|
||||
os_family = COALESCE($4, os_family),
|
||||
os_name = COALESCE($5, os_name),
|
||||
arch = COALESCE($6, arch)
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&status)
|
||||
.bind(&agent_version)
|
||||
.bind(sys_info.as_ref().map(|i| i.os.as_str()))
|
||||
.bind(os_name_from_sysinfo)
|
||||
.bind(sys_info.as_ref().map(|i| i.architecture.as_str()))
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to update host status");
|
||||
}
|
||||
|
||||
status
|
||||
}
|
||||
1057
crates/pm-worker/src/job_executor.rs
Executable file
1057
crates/pm-worker/src/job_executor.rs
Executable file
File diff suppressed because it is too large
Load Diff
208
crates/pm-worker/src/main.rs
Executable file
208
crates/pm-worker/src/main.rs
Executable file
@ -0,0 +1,208 @@
|
||||
//! pm-worker — Linux Patch Manager background worker.
|
||||
//!
|
||||
//! Handles scheduled polling, job execution, maintenance window scheduling,
|
||||
//! retry logic, email notifications, audit integrity verification, and data pruning.
|
||||
|
||||
mod agent_loader;
|
||||
mod audit_verifier;
|
||||
mod email;
|
||||
mod health_check_poller;
|
||||
mod health_poller;
|
||||
mod job_executor;
|
||||
mod maintenance_scheduler;
|
||||
mod patch_poller;
|
||||
mod refresh_listener;
|
||||
mod ws_relay;
|
||||
|
||||
use chrono::Utc;
|
||||
use pm_core::{config::AppConfig, db, logging};
|
||||
use sqlx::PgPool;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time;
|
||||
|
||||
use audit_verifier::run_audit_verifier;
|
||||
use health_check_poller::run_health_check_poller;
|
||||
use health_poller::run_health_poller;
|
||||
use job_executor::run_job_executor;
|
||||
use maintenance_scheduler::run_maintenance_scheduler;
|
||||
use patch_poller::run_patch_poller;
|
||||
use refresh_listener::run_refresh_listener;
|
||||
use ws_relay::run_ws_relay;
|
||||
|
||||
/// Minimum number of applied migrations the worker requires before
|
||||
/// accepting work. Prevents the worker from running against a schema
|
||||
/// that hasn't been migrated yet.
|
||||
const REQUIRED_MIGRATION_COUNT: i64 = 16;
|
||||
|
||||
/// Maximum total time to keep re-checking the schema version before giving up.
|
||||
const SCHEMA_CHECK_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Install the default crypto provider for rustls (required since 0.23)
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install rustls crypto provider");
|
||||
|
||||
// Load configuration
|
||||
let config_path = std::env::var("PATCH_MANAGER_CONFIG")
|
||||
.unwrap_or_else(|_| "/etc/patch-manager/config.toml".to_string());
|
||||
|
||||
let config = AppConfig::load(&config_path).unwrap_or_else(|_| {
|
||||
eprintln!("Config file not found or invalid, using defaults");
|
||||
AppConfig::default()
|
||||
});
|
||||
|
||||
// Initialize logging
|
||||
logging::init(&config.logging);
|
||||
|
||||
tracing::info!(
|
||||
version = env!("CARGO_PKG_VERSION"),
|
||||
"patch-manager-worker starting"
|
||||
);
|
||||
|
||||
// Initialize database pool
|
||||
let pool = db::init_pool(&config.database).await?;
|
||||
|
||||
// Wait for schema to be at the expected version (web process runs migrations)
|
||||
wait_for_schema(&pool).await?;
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Spawn worker tasks
|
||||
let heartbeat_handle = tokio::spawn(run_heartbeat(
|
||||
pool.clone(),
|
||||
config.worker.heartbeat_interval_secs,
|
||||
));
|
||||
|
||||
// M4: agent health poller, patch data poller, on-demand refresh listener
|
||||
let health_handle = tokio::spawn(run_health_poller(pool.clone(), config.clone()));
|
||||
let patch_handle = tokio::spawn(run_patch_poller(pool.clone(), config.clone()));
|
||||
let refresh_handle = tokio::spawn(run_refresh_listener(pool.clone(), config.clone()));
|
||||
|
||||
// M5: job execution engine
|
||||
let job_exec_handle = tokio::spawn(run_job_executor(pool.clone(), config.clone()));
|
||||
|
||||
// M6: maintenance window scheduler
|
||||
let maint_sched_handle = tokio::spawn(run_maintenance_scheduler(pool.clone(), config.clone()));
|
||||
|
||||
// M7: WS relay — streams agent job events → DB → pg_notify → browser WS
|
||||
let ws_relay_handle = tokio::spawn(run_ws_relay(pool.clone(), config.clone()));
|
||||
|
||||
// M11: audit integrity verification (runs every 24 hours)
|
||||
let audit_verifier_handle = tokio::spawn(run_audit_verifier(pool.clone(), config.clone()));
|
||||
|
||||
// Health check poller — runs configured service/HTTP health checks
|
||||
let health_check_handle = tokio::spawn(run_health_check_poller(pool.clone(), config.clone()));
|
||||
|
||||
// Enrollment cleanup task (runs every hour)
|
||||
let enrollment_cleanup_handle = tokio::spawn(run_enrollment_cleanup_task(pool.clone()));
|
||||
|
||||
tracing::info!("Worker tasks started");
|
||||
|
||||
// Wait for all tasks (they run indefinitely)
|
||||
let _ = tokio::join!(
|
||||
heartbeat_handle,
|
||||
health_handle,
|
||||
patch_handle,
|
||||
refresh_handle,
|
||||
job_exec_handle,
|
||||
maint_sched_handle,
|
||||
ws_relay_handle,
|
||||
audit_verifier_handle,
|
||||
health_check_handle,
|
||||
enrollment_cleanup_handle,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Wait until the database schema has at least `REQUIRED_MIGRATION_COUNT`
|
||||
/// successful migrations applied. Retries every 5 seconds up to
|
||||
/// `SCHEMA_CHECK_TIMEOUT`.
|
||||
async fn wait_for_schema(pool: &PgPool) -> anyhow::Result<()> {
|
||||
let deadline = tokio::time::Instant::now() + SCHEMA_CHECK_TIMEOUT;
|
||||
|
||||
loop {
|
||||
match db::check_schema_version(pool).await {
|
||||
Ok(count) if count >= REQUIRED_MIGRATION_COUNT => {
|
||||
tracing::info!(migration_count = count, "Schema version check passed");
|
||||
return Ok(());
|
||||
},
|
||||
Ok(count) => {
|
||||
tracing::warn!(
|
||||
migration_count = count,
|
||||
required = REQUIRED_MIGRATION_COUNT,
|
||||
"Schema not ready, waiting..."
|
||||
);
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Schema version check failed, retrying...");
|
||||
},
|
||||
}
|
||||
|
||||
if tokio::time::Instant::now() >= deadline {
|
||||
anyhow::bail!(
|
||||
"Schema not ready after {}s — is the web process running migrations?",
|
||||
SCHEMA_CHECK_TIMEOUT.as_secs()
|
||||
);
|
||||
}
|
||||
|
||||
time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes a heartbeat row to `worker_heartbeat` every `interval_secs`.
|
||||
/// The web process can query this to confirm the worker is alive.
|
||||
async fn run_heartbeat(pool: PgPool, interval_secs: u64) {
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
let mut ticker = time::interval(interval);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO worker_heartbeat (id, last_seen, worker_version)
|
||||
VALUES (1, NOW(), $1)
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET last_seen = EXCLUDED.last_seen,
|
||||
worker_version = EXCLUDED.worker_version
|
||||
"#,
|
||||
)
|
||||
.bind(env!("CARGO_PKG_VERSION"))
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => tracing::debug!("Worker heartbeat written"),
|
||||
Err(e) => tracing::error!(error = %e, "Worker heartbeat failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Periodically deletes expired enrollment requests.
|
||||
async fn run_enrollment_cleanup_task(pool: PgPool) {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(3600)); // Every hour
|
||||
interval.tick().await; // Initial tick to run immediately if needed
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = Utc::now();
|
||||
match sqlx::query("DELETE FROM enrollment_requests WHERE expires_at < $1")
|
||||
.bind(now)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!(
|
||||
removed = result.rows_affected(),
|
||||
"Purged expired enrollment requests"
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => tracing::error!(error = %e, "Failed to purge expired enrollment requests"),
|
||||
}
|
||||
}
|
||||
}
|
||||
381
crates/pm-worker/src/maintenance_scheduler.rs
Executable file
381
crates/pm-worker/src/maintenance_scheduler.rs
Executable file
@ -0,0 +1,381 @@
|
||||
//! Maintenance window scheduler.
|
||||
//!
|
||||
//! Polls every 60 seconds and performs two tasks:
|
||||
//!
|
||||
//! 1. **Auto-apply**: For each enabled maintenance window with `auto_apply = true`
|
||||
//! that is currently open, if the host has pending patches and no existing
|
||||
//! patch_apply job queued/running for that window, automatically creates one.
|
||||
//!
|
||||
//! 2. **Dispatch**: For each open window, dispatch any queued non-immediate
|
||||
//! patch jobs associated with the window's host.
|
||||
//!
|
||||
//! A window is considered "open" when:
|
||||
//! - `once` — `start_at <= NOW() < start_at + duration_minutes * '1 minute'`
|
||||
//! - `daily` — current UTC time-of-day is within the window's daily slot
|
||||
//! - `weekly` — same as daily, but only on the matching `recurrence_day` (0=Sun)
|
||||
//! - `monthly` — same as daily, but only on the matching `recurrence_day` (1-31)
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::time;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::job_executor::process_job;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Internal types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct OpenWindowHost {
|
||||
host_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct QueuedJobId {
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct AutoApplyWindow {
|
||||
window_id: Uuid,
|
||||
host_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
#[allow(dead_code)]
|
||||
struct PendingPatchHost {
|
||||
host_id: Uuid,
|
||||
patch_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct InsertedJobId {
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public entry point
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run the maintenance scheduler indefinitely.
|
||||
/// Spawned by `pm-worker/src/main.rs` alongside the job executor.
|
||||
pub async fn run_maintenance_scheduler(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Maintenance scheduler started");
|
||||
|
||||
// First tick fires immediately; consume it to align with job_executor.
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(60));
|
||||
ticker.tick().await;
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
tracing::debug!("Maintenance scheduler: checking open windows");
|
||||
|
||||
// Step 1: Auto-create patch_apply jobs for windows with auto_apply=true
|
||||
auto_create_patch_jobs(pool.clone(), config.clone()).await;
|
||||
|
||||
// Step 2: Dispatch any queued non-immediate jobs for open windows
|
||||
dispatch_open_window_jobs(pool.clone(), config.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Step 1: Auto-create patch_apply jobs
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// For each enabled maintenance window that is currently open AND has
|
||||
/// `auto_apply = true`, check if the host has pending patches and no
|
||||
/// existing patch_apply job for this window cycle. If so, create one.
|
||||
async fn auto_create_patch_jobs(pool: PgPool, _config: Arc<AppConfig>) {
|
||||
// Find all open windows with auto_apply=true
|
||||
let auto_windows: Vec<AutoApplyWindow> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT mw.id AS window_id, mw.host_id
|
||||
FROM maintenance_windows mw
|
||||
WHERE mw.enabled = TRUE
|
||||
AND mw.auto_apply = TRUE
|
||||
AND (
|
||||
( mw.recurrence = 'once'
|
||||
AND mw.start_at <= NOW()
|
||||
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
|
||||
)
|
||||
OR
|
||||
( mw.recurrence = 'daily'
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
( mw.recurrence = 'weekly'
|
||||
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
( mw.recurrence = 'monthly'
|
||||
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(w) => w,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "auto_create_patch_jobs: open-windows query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if auto_windows.is_empty() {
|
||||
tracing::debug!("auto_create: no open auto-apply windows this cycle");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
auto_window_count = auto_windows.len(),
|
||||
"auto_create: found open auto-apply windows"
|
||||
);
|
||||
|
||||
for win in &auto_windows {
|
||||
// Check if host has pending patches
|
||||
let pending: Option<PendingPatchHost> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT host_id, patch_count
|
||||
FROM host_patch_data
|
||||
WHERE host_id = $1 AND patch_count > 0
|
||||
"#,
|
||||
)
|
||||
.bind(win.host_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %win.host_id,
|
||||
"auto_create: patch data query failed"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let Some(pending) = pending else {
|
||||
tracing::debug!(
|
||||
host_id = %win.host_id,
|
||||
"auto_create: no pending patches, skipping"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
// Check if there's already a queued/running patch_apply job for this host
|
||||
// that was created during this window cycle (within the window's time range).
|
||||
// We use a simpler check: any non-completed patch_apply job for this host
|
||||
// that references this maintenance window, OR any non-immediate job without
|
||||
// a window that was created since the window opened.
|
||||
let existing_job: bool = match sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM patch_jobs pj
|
||||
JOIN patch_job_hosts pjh ON pj.id = pjh.job_id
|
||||
WHERE pjh.host_id = $1
|
||||
AND pj.status IN ('queued', 'running', 'pending')
|
||||
AND pj.kind = 'patch_apply'
|
||||
AND (
|
||||
pj.maintenance_window_id = $2
|
||||
OR
|
||||
(pj.immediate = FALSE AND pj.created_at >=
|
||||
(SELECT start_at - INTERVAL '5 minutes' FROM maintenance_windows WHERE id = $2)
|
||||
)
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(win.host_id)
|
||||
.bind(win.window_id)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %win.host_id,
|
||||
"auto_create: existing job check failed"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if existing_job {
|
||||
tracing::debug!(
|
||||
host_id = %win.host_id,
|
||||
window_id = %win.window_id,
|
||||
"auto_create: existing job already queued/running, skipping"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create a new patch_apply job for this host, linked to the window.
|
||||
let job: Option<InsertedJobId> = match sqlx::query_as(
|
||||
r#"
|
||||
WITH new_job AS (
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, maintenance_window_id, immediate, patch_selection, notes)
|
||||
VALUES
|
||||
('patch_apply', 'queued', $1, FALSE, '[]'::jsonb,
|
||||
'Auto-created by maintenance window scheduler')
|
||||
RETURNING id AS job_id
|
||||
)
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
SELECT new_job.job_id, $2, 'queued'
|
||||
FROM new_job
|
||||
RETURNING job_id
|
||||
"#,
|
||||
)
|
||||
.bind(win.window_id)
|
||||
.bind(win.host_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %win.host_id,
|
||||
window_id = %win.window_id,
|
||||
"auto_create: job insert failed"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(job) = job {
|
||||
tracing::info!(
|
||||
job_id = %job.job_id,
|
||||
host_id = %win.host_id,
|
||||
window_id = %win.window_id,
|
||||
patch_count = pending.patch_count,
|
||||
"auto_create: created patch_apply job for host in maintenance window"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Step 2: Dispatch queued non-immediate jobs
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Find all hosts with a currently-open maintenance window, then for each,
|
||||
/// find their queued non-immediate job entries and dispatch them.
|
||||
async fn dispatch_open_window_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
// ── 1. Find all host_ids with an open window right now ─────────────────
|
||||
let open_hosts: Vec<OpenWindowHost> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT mw.host_id
|
||||
FROM maintenance_windows mw
|
||||
WHERE mw.enabled = TRUE
|
||||
AND (
|
||||
-- One-time: absolute window
|
||||
( mw.recurrence = 'once'
|
||||
AND mw.start_at <= NOW()
|
||||
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
|
||||
)
|
||||
OR
|
||||
-- Daily: time-of-day slot, any day
|
||||
( mw.recurrence = 'daily'
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
-- Weekly: matching day-of-week + time-of-day slot
|
||||
( mw.recurrence = 'weekly'
|
||||
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
-- Monthly: matching day-of-month + time-of-day slot
|
||||
( mw.recurrence = 'monthly'
|
||||
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(hosts) => hosts,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "dispatch_open_window_jobs: open-hosts query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if open_hosts.is_empty() {
|
||||
tracing::debug!("Maintenance scheduler: no open windows this cycle");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
open_host_count = open_hosts.len(),
|
||||
"Maintenance scheduler: found hosts with open windows"
|
||||
);
|
||||
|
||||
// ── 2. For each open host, find distinct queued non-immediate job IDs ──
|
||||
for host in open_hosts {
|
||||
let job_ids: Vec<QueuedJobId> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT pjh.job_id
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs j ON j.id = pjh.job_id
|
||||
WHERE pjh.host_id = $1
|
||||
AND pjh.status = 'queued'
|
||||
AND j.immediate = FALSE
|
||||
AND j.status != 'cancelled'
|
||||
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
|
||||
"#,
|
||||
)
|
||||
.bind(host.host_id)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %host.host_id,
|
||||
"dispatch_open_window_jobs: queued jobs query failed"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
for job in job_ids {
|
||||
tracing::info!(
|
||||
job_id = %job.job_id,
|
||||
host_id = %host.host_id,
|
||||
"Maintenance scheduler: dispatching non-immediate job (window open)"
|
||||
);
|
||||
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
let job_id = job.job_id;
|
||||
tokio::spawn(async move {
|
||||
process_job(p, c, job_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
202
crates/pm-worker/src/patch_poller.rs
Executable file
202
crates/pm-worker/src/patch_poller.rs
Executable file
@ -0,0 +1,202 @@
|
||||
//! Periodic patch-data poller for all registered hosts.
|
||||
//!
|
||||
//! Polls every host via the agent `/patches` and `/packages` endpoints on
|
||||
//! each tick of `patch_poll_interval_secs`, with bounded concurrency
|
||||
//! controlled by a [`tokio::sync::Semaphore`].
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::AgentClient;
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host projection fetched for each poll cycle.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the patch poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all registered hosts are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_patch_data` and `hosts.last_patch_at` is updated.
|
||||
pub async fn run_patch_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.patch_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(interval_secs, "Patch poller started");
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller: failed to load agent certs — skipping cycle");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
let hosts: Vec<HostRow> = match sqlx::query_as(
|
||||
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts ORDER BY id",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller: failed to fetch hosts");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!("Patch poller: no hosts registered, skipping cycle");
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = hosts.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for host in hosts {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
poll_host_patches(pool, host, &cert, &key, &ca).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
let mut succeeded = 0usize;
|
||||
let mut failed = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(true) => succeeded += 1,
|
||||
Ok(false) => failed += 1,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller task panicked");
|
||||
failed += 1;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(total, succeeded, failed, "Patch poll cycle complete");
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll a single host for patch and package data, persist the result.
|
||||
/// Returns `true` on success, `false` on any error.
|
||||
async fn poll_host_patches(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> bool {
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: failed to build AgentClient");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
// Fetch patches and packages concurrently.
|
||||
let (patches_result, packages_result) =
|
||||
tokio::join!(client.patches(), client.packages_upgradable());
|
||||
|
||||
let patches_data = match patches_result {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: patches() failed");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let packages_data = match packages_result {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: packages_upgradable() failed");
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
|
||||
let installed_packages = serde_json::to_value(&packages_data.packages).unwrap_or_default();
|
||||
let patch_count = patches_data.total as i32;
|
||||
let cve_count = patches_data
|
||||
.patches
|
||||
.iter()
|
||||
.filter(|p| !p.cve_ids.is_empty())
|
||||
.count() as i32;
|
||||
|
||||
// Upsert into host_patch_data (one row per host, latest poll wins).
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_patch_data
|
||||
(host_id, available_patches, installed_packages, patch_count, cve_count)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (host_id) DO UPDATE SET
|
||||
available_patches = EXCLUDED.available_patches,
|
||||
installed_packages = EXCLUDED.installed_packages,
|
||||
patch_count = EXCLUDED.patch_count,
|
||||
cve_count = EXCLUDED.cve_count,
|
||||
polled_at = NOW()
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&available_patches)
|
||||
.bind(&installed_packages)
|
||||
.bind(patch_count)
|
||||
.bind(cve_count)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to insert patch data");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update hosts.last_patch_at.
|
||||
if let Err(e) = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
|
||||
.bind(host.id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to update last_patch_at");
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
host_id = %host.id,
|
||||
patch_count,
|
||||
cve_count,
|
||||
"Patch data collected"
|
||||
);
|
||||
|
||||
true
|
||||
}
|
||||
269
crates/pm-worker/src/refresh_listener.rs
Executable file
269
crates/pm-worker/src/refresh_listener.rs
Executable file
@ -0,0 +1,269 @@
|
||||
//! On-demand refresh listener.
|
||||
//!
|
||||
//! Listens on the PostgreSQL `refresh_requested` NOTIFY channel. When a
|
||||
//! notification arrives the payload is expected to be a host UUID string.
|
||||
//! The listener immediately polls that host for health and patch data and
|
||||
//! persists the results — bypassing the normal poll intervals.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
use pm_core::{config::AppConfig, models::HostHealthStatus};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::time;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host row used for on-demand refresh.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the LISTEN/NOTIFY refresh listener indefinitely.
|
||||
///
|
||||
/// Automatically reconnects if the underlying PostgreSQL connection drops.
|
||||
pub async fn run_refresh_listener(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Refresh listener started — listening on 'refresh_requested'");
|
||||
|
||||
loop {
|
||||
if let Err(e) = listen_loop(&pool, &config).await {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Refresh listener disconnected, reconnecting in 5s"
|
||||
);
|
||||
time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inner loop — returns `Err` only on a fatal listener error so the outer
|
||||
/// loop can reconnect.
|
||||
async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> {
|
||||
let mut listener = sqlx::postgres::PgListener::connect(&config.database.url).await?;
|
||||
|
||||
listener.listen("refresh_requested").await?;
|
||||
|
||||
tracing::debug!("Refresh listener connected and listening");
|
||||
|
||||
loop {
|
||||
let notification = listener.recv().await?;
|
||||
let payload = notification.payload().to_string();
|
||||
|
||||
tracing::info!(payload, "Refresh notification received");
|
||||
|
||||
let host_id = match payload.parse::<Uuid>() {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
payload,
|
||||
error = %e,
|
||||
"Refresh listener: invalid UUID in notification payload"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Fetch the host from the database.
|
||||
let host: Option<HostRow> = sqlx::query_as(
|
||||
"SELECT id, host(ip_address)::text AS ip_address, agent_port FROM hosts WHERE id = $1",
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let host = match host {
|
||||
Some(h) => h,
|
||||
None => {
|
||||
tracing::warn!(%host_id, "Refresh listener: host not found");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Load certs for this refresh.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
%host_id,
|
||||
error = %e,
|
||||
"Refresh listener: failed to load agent certs"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Spawn the actual work so the listener loop is not blocked.
|
||||
let pool_clone = pool.clone();
|
||||
let cert = certs.client_cert;
|
||||
let key = certs.client_key;
|
||||
let ca = certs.ca_cert;
|
||||
|
||||
tokio::spawn(async move {
|
||||
refresh_host(pool_clone, host, &cert, &key, &ca).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a full health + patch refresh for one host and persist results.
|
||||
async fn refresh_host(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) {
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to build AgentClient"
|
||||
);
|
||||
persist_health_unreachable(&pool, host.id).await;
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
// ── Health ────────────────────────────────────────────────────────────
|
||||
let (health_status, health_payload) = match client.health().await {
|
||||
Ok(data) => {
|
||||
let payload = serde_json::to_value(&data).unwrap_or_default();
|
||||
(HostHealthStatus::Healthy, payload)
|
||||
},
|
||||
Err(AgentClientError::Timeout) | Err(AgentClientError::Connect(_)) => {
|
||||
tracing::warn!(host_id = %host.id, "Refresh: agent unreachable");
|
||||
(
|
||||
HostHealthStatus::Unreachable,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
)
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Refresh: health error");
|
||||
(
|
||||
HostHealthStatus::Degraded,
|
||||
serde_json::Value::Object(Default::default()),
|
||||
)
|
||||
},
|
||||
};
|
||||
|
||||
persist_health(&pool, host.id, &health_status, &health_payload).await;
|
||||
|
||||
// ── Patch data ────────────────────────────────────────────────────────
|
||||
let (patches_result, packages_result) =
|
||||
tokio::join!(client.patches(), client.packages_upgradable());
|
||||
|
||||
match (patches_result, packages_result) {
|
||||
(Ok(patches_data), Ok(packages_data)) => {
|
||||
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
|
||||
let installed_packages =
|
||||
serde_json::to_value(&packages_data.packages).unwrap_or_default();
|
||||
let patch_count = patches_data.total as i32;
|
||||
let cve_count = patches_data
|
||||
.patches
|
||||
.iter()
|
||||
.filter(|p| !p.cve_ids.is_empty())
|
||||
.count() as i32;
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_patch_data
|
||||
(host_id, available_patches, installed_packages, patch_count, cve_count)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (host_id) DO UPDATE SET
|
||||
available_patches = EXCLUDED.available_patches,
|
||||
installed_packages = EXCLUDED.installed_packages,
|
||||
patch_count = EXCLUDED.patch_count,
|
||||
cve_count = EXCLUDED.cve_count,
|
||||
polled_at = NOW()
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&available_patches)
|
||||
.bind(&installed_packages)
|
||||
.bind(patch_count)
|
||||
.bind(cve_count)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to insert patch data"
|
||||
);
|
||||
} else {
|
||||
let _ = sqlx::query("UPDATE hosts SET last_patch_at = NOW() WHERE id = $1")
|
||||
.bind(host.id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
host_id = %host.id,
|
||||
patch_count,
|
||||
cve_count,
|
||||
"On-demand refresh complete"
|
||||
);
|
||||
}
|
||||
},
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to collect patch data"
|
||||
);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn persist_health_unreachable(pool: &PgPool, host_id: Uuid) {
|
||||
let status = HostHealthStatus::Unreachable;
|
||||
let payload = serde_json::Value::Object(Default::default());
|
||||
persist_health(pool, host_id, &status, &payload).await;
|
||||
}
|
||||
|
||||
async fn persist_health(
|
||||
pool: &PgPool,
|
||||
host_id: Uuid,
|
||||
status: &HostHealthStatus,
|
||||
payload: &serde_json::Value,
|
||||
) {
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_data (host_id, status, payload)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(status)
|
||||
.bind(payload)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
%host_id,
|
||||
error = %e,
|
||||
"Refresh: failed to insert health data"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
sqlx::query("UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1")
|
||||
.bind(host_id)
|
||||
.bind(status)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%host_id, error = %e, "Refresh: failed to update host health_status");
|
||||
}
|
||||
}
|
||||
718
crates/pm-worker/src/ws_relay.rs
Executable file
718
crates/pm-worker/src/ws_relay.rs
Executable file
@ -0,0 +1,718 @@
|
||||
//! WS relay — M7
|
||||
//!
|
||||
//! For every running `patch_job_hosts` row that has an `agent_job_id`, open a
|
||||
//! WebSocket to the corresponding agent, stream job-status events, update the
|
||||
//! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS
|
||||
//! handler can forward the event to connected clients.
|
||||
|
||||
use std::{collections::HashSet, error::Error, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::StreamExt;
|
||||
use rustls::{
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
ClientConfig as TlsClientConfig, RootCertStore,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::protocol::Message, Connector};
|
||||
use uuid::Uuid;
|
||||
|
||||
use pm_agent_client::client::AgentClient;
|
||||
use pm_agent_client::client::DEFAULT_AGENT_PORT;
|
||||
use pm_core::config::AppConfig;
|
||||
|
||||
// ── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct RunningHostJob {
|
||||
job_id: Uuid,
|
||||
host_id: Uuid,
|
||||
agent_job_id: String,
|
||||
host_address: String,
|
||||
}
|
||||
|
||||
/// JSON event streamed by the agent over its WS endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AgentWsEvent {
|
||||
#[allow(dead_code)]
|
||||
job_id: String,
|
||||
status: String,
|
||||
output: Option<String>,
|
||||
error: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
progress_percent: Option<u8>,
|
||||
}
|
||||
|
||||
/// Payload broadcast via `pg_notify('job_update', …)`.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NotifyPayload {
|
||||
event_type: String, // "host" or "job"
|
||||
job_id: String,
|
||||
host_id: String,
|
||||
status: String,
|
||||
output: Option<String>,
|
||||
error_message: Option<String>,
|
||||
agent_job_id: String,
|
||||
// Job-level fields (only present when event_type === "job")
|
||||
succeeded_count: Option<i64>,
|
||||
failed_count: Option<i64>,
|
||||
host_count: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Cert PEM bytes for building AgentClient ────────────────────────────────────
|
||||
|
||||
/// Unparsed PEM file contents for the agent mTLS material.
///
/// Handed to [`AgentClient`] when the relay falls back to HTTP polling.
#[derive(Clone)]
struct CertPems {
    /// Bytes of the file at `agent_client_cert_path`.
    client_cert: Vec<u8>,
    /// Bytes of the file at `agent_client_key_path`.
    client_key: Vec<u8>,
    /// Bytes of the file at `ca_cert_path`.
    ca_cert: Vec<u8>,
}
|
||||
|
||||
/// Read the three PEM files referenced by the security config.
|
||||
/// Mirrors the file reads in [`build_tls_config`] but returns raw bytes instead of
|
||||
/// parsing them into rustls types.
|
||||
async fn read_cert_pems(config: &AppConfig) -> anyhow::Result<CertPems> {
|
||||
let sec = &config.security;
|
||||
let client_cert = tokio::fs::read(&sec.agent_client_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
|
||||
let client_key = tokio::fs::read(&sec.agent_client_key_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
|
||||
let ca_cert = tokio::fs::read(&sec.ca_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
|
||||
Ok(CertPems {
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Long-running task: polls the DB for running host-jobs and spawns a per-pair
|
||||
/// relay task for each one that isn't already being tracked.
|
||||
pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("WS relay task started");
|
||||
|
||||
let active: Arc<Mutex<HashSet<(Uuid, Uuid)>>> = Arc::new(Mutex::new(HashSet::new()));
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(10));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let rows = match query_running_jobs(&pool).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: DB poll failed");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
for row in rows {
|
||||
let key = (row.job_id, row.host_id);
|
||||
|
||||
// Skip pairs that already have an active relay.
|
||||
if active.lock().await.contains(&key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build the rustls ClientConfig once per connection.
|
||||
let tls_config = match build_tls_config(&config).await {
|
||||
Ok(c) => Arc::new(c),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: TLS config error");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Read raw cert PEM bytes for HTTP polling fallback.
|
||||
let cert_pems = match read_cert_pems(&config).await {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: cert PEM read error");
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
let poll_interval = config.worker.ws_relay_poll_interval_secs;
|
||||
|
||||
active.lock().await.insert(key);
|
||||
|
||||
let pool_c = pool.clone();
|
||||
let active_c = active.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
agent_job_id = %row.agent_job_id,
|
||||
host = %row.host_address,
|
||||
"WS relay: starting relay"
|
||||
);
|
||||
|
||||
match relay_one_job(&pool_c, &row, tls_config, &cert_pems, poll_interval).await {
|
||||
Ok(()) => tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: completed"
|
||||
),
|
||||
Err(e) => tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: ended with error"
|
||||
),
|
||||
}
|
||||
|
||||
active_c.lock().await.remove(&key);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── DB helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn query_running_jobs(pool: &PgPool) -> anyhow::Result<Vec<RunningHostJob>> {
|
||||
sqlx::query_as::<_, RunningHostJob>(
|
||||
r#"
|
||||
SELECT
|
||||
pjh.job_id,
|
||||
pjh.host_id,
|
||||
pjh.agent_job_id,
|
||||
COALESCE(h.fqdn, host(h.ip_address)::text) AS host_address
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.status = 'running'::job_status
|
||||
AND pjh.agent_job_id IS NOT NULL
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("query_running_jobs")
|
||||
}
|
||||
|
||||
// ── TLS ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn build_tls_config(config: &AppConfig) -> anyhow::Result<TlsClientConfig> {
|
||||
let sec = &config.security;
|
||||
|
||||
let cert_pem = tokio::fs::read(&sec.agent_client_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
|
||||
let key_pem = tokio::fs::read(&sec.agent_client_key_path)
|
||||
.await
|
||||
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
|
||||
let ca_pem = tokio::fs::read(&sec.ca_cert_path)
|
||||
.await
|
||||
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
|
||||
|
||||
// Parse client certificate chain.
|
||||
let client_certs: Vec<CertificateDer<'static>> = {
|
||||
let mut cur = std::io::Cursor::new(&cert_pem);
|
||||
rustls_pemfile::certs(&mut cur)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.context("parse client cert PEM")?
|
||||
};
|
||||
|
||||
// Parse client private key.
|
||||
let client_key: PrivateKeyDer<'static> = {
|
||||
let mut cur = std::io::Cursor::new(&key_pem);
|
||||
rustls_pemfile::private_key(&mut cur)
|
||||
.context("parse client key PEM")?
|
||||
.context("no private key in PEM")?
|
||||
};
|
||||
|
||||
// Build root store from CA cert.
|
||||
let mut root_store = RootCertStore::empty();
|
||||
{
|
||||
let mut cur = std::io::Cursor::new(&ca_pem);
|
||||
for cert_result in rustls_pemfile::certs(&mut cur) {
|
||||
root_store
|
||||
.add(cert_result.context("read CA cert entry")?)
|
||||
.context("add CA cert to root store")?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut config = TlsClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_client_auth_cert(client_certs, client_key)
|
||||
.context("build TlsClientConfig")?;
|
||||
|
||||
// WebSocket requires HTTP/1.1 — without ALPN the server may negotiate
|
||||
// h2 (HTTP/2), which breaks the WebSocket upgrade handshake
|
||||
// ("Key mismatch in Sec-WebSocket-Accept header").
|
||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
// ── Per-job relay ─────────────────────────────────────────────────────────────
|
||||
|
||||
async fn relay_one_job(
|
||||
pool: &PgPool,
|
||||
row: &RunningHostJob,
|
||||
tls_config: Arc<TlsClientConfig>,
|
||||
cert_pems: &CertPems,
|
||||
poll_interval_secs: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!(
|
||||
"wss://{}:{}/api/v1/ws/jobs",
|
||||
row.host_address, DEFAULT_AGENT_PORT,
|
||||
);
|
||||
|
||||
let (ws_stream, _) = match connect_async_tls_with_config(
|
||||
url.as_str(),
|
||||
None,
|
||||
false,
|
||||
Some(Connector::Rustls(tls_config)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(ws) => ws,
|
||||
Err(e) => {
|
||||
// Log the full error chain for TLS debugging
|
||||
let mut source = e.source();
|
||||
let mut depth = 0;
|
||||
while let Some(err) = source {
|
||||
tracing::warn!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
depth,
|
||||
error = %err,
|
||||
"WS relay: TLS connection error detail"
|
||||
);
|
||||
source = err.source();
|
||||
depth += 1;
|
||||
}
|
||||
// Fall back to HTTP polling instead of returning an error.
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
host = %row.host_address,
|
||||
"WS relay: WebSocket connection failed, falling back to HTTP polling"
|
||||
);
|
||||
return relay_one_job_poll(pool, row, cert_pems, poll_interval_secs).await;
|
||||
},
|
||||
};
|
||||
|
||||
let (_sink, mut stream) = ws_stream.split();
|
||||
|
||||
while let Some(frame) = stream.next().await {
|
||||
let frame = match frame {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: stream error"
|
||||
);
|
||||
break;
|
||||
},
|
||||
};
|
||||
|
||||
let text = match frame {
|
||||
Message::Text(t) => t.to_string(),
|
||||
Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(),
|
||||
Message::Close(_) => {
|
||||
tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly");
|
||||
break;
|
||||
},
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event: AgentWsEvent = match serde_json::from_str(&text) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e, raw = %text,
|
||||
"WS relay: unparseable agent frame"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
process_event(pool, row, &event).await;
|
||||
|
||||
if matches!(event.status.as_str(), "succeeded" | "failed" | "cancelled") {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %event.status,
|
||||
"WS relay: terminal state — stopping"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Per-job HTTP polling fallback ─────────────────────────────────────────────
|
||||
|
||||
/// Fall back to HTTP polling when the WebSocket connection fails.
|
||||
///
|
||||
/// Builds an [`AgentClient`] from the same cert paths used for TLS, polls
|
||||
/// `GET /api/v1/jobs/{id}` every `poll_interval_secs`, and calls
|
||||
/// [`process_event`] whenever the status changes from the previous poll.
|
||||
/// Stops when the job reaches a terminal state (succeeded/failed/cancelled).
|
||||
async fn relay_one_job_poll(
|
||||
pool: &PgPool,
|
||||
row: &RunningHostJob,
|
||||
cert_pems: &CertPems,
|
||||
poll_interval_secs: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
// Build the HTTP client using the same mTLS certs.
|
||||
let agent_client = AgentClient::new(
|
||||
&row.host_address,
|
||||
DEFAULT_AGENT_PORT,
|
||||
&cert_pems.client_cert,
|
||||
&cert_pems.client_key,
|
||||
&cert_pems.ca_cert,
|
||||
)
|
||||
.context("build AgentClient for polling fallback")?;
|
||||
|
||||
let mut last_status: Option<String> = None;
|
||||
let poll_interval = Duration::from_secs(poll_interval_secs);
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let job_status = match agent_client.job_status(&row.agent_job_id).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
error = %e,
|
||||
"WS relay poll: job_status request failed, will retry"
|
||||
);
|
||||
continue;
|
||||
},
|
||||
};
|
||||
|
||||
// Map agent status to the WS event format.
|
||||
// The agent uses "completed" but pm-worker expects "succeeded".
|
||||
// The agent uses "queued" but pm-worker expects "running".
|
||||
let mapped_status = match job_status.status.as_str() {
|
||||
"queued" => "running",
|
||||
"running" => "running",
|
||||
"succeeded" => "succeeded",
|
||||
"completed" => "succeeded",
|
||||
"failed" => "failed",
|
||||
"cancelled" => "cancelled",
|
||||
other => {
|
||||
tracing::warn!(
|
||||
status = %other,
|
||||
job_id = %row.job_id,
|
||||
"WS relay poll: unknown agent status, treating as running"
|
||||
);
|
||||
"running"
|
||||
},
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
agent_status = %job_status.status,
|
||||
mapped_status = %mapped_status,
|
||||
progress = ?job_status.progress_percent,
|
||||
"WS relay poll: fetched job status"
|
||||
);
|
||||
|
||||
// Only process when the status has changed since the last poll.
|
||||
if last_status.as_deref() == Some(mapped_status) {
|
||||
continue;
|
||||
}
|
||||
|
||||
last_status = Some(mapped_status.to_string());
|
||||
|
||||
let event = AgentWsEvent {
|
||||
job_id: job_status.job_id.clone(),
|
||||
status: mapped_status.to_string(),
|
||||
output: job_status.output.clone(),
|
||||
error: job_status.error.clone(),
|
||||
progress_percent: job_status.progress_percent,
|
||||
};
|
||||
|
||||
process_event(pool, row, &event).await;
|
||||
|
||||
// Stop polling on terminal states.
|
||||
if matches!(mapped_status, "succeeded" | "failed" | "cancelled") {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %mapped_status,
|
||||
"WS relay poll: terminal state — stopping"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Event processing ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) {
|
||||
// Map agent status string to DB job_status enum value.
|
||||
let db_status = match event.status.as_str() {
|
||||
"running" => "running",
|
||||
"succeeded" => "succeeded",
|
||||
"failed" => "failed",
|
||||
"cancelled" => "cancelled",
|
||||
other => {
|
||||
tracing::warn!(status = %other, "WS relay: unknown agent status");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
let output = event.output.as_deref().unwrap_or("");
|
||||
let error_msg = event.error.as_deref();
|
||||
|
||||
// Determine timestamps based on terminal state.
|
||||
let is_terminal = matches!(db_status, "succeeded" | "failed" | "cancelled");
|
||||
|
||||
// Update the DB row.
|
||||
let update_result = if is_terminal {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = $1::job_status,
|
||||
output = CASE WHEN $2 != '' THEN $2 ELSE output END,
|
||||
error_message = $3,
|
||||
completed_at = NOW()
|
||||
WHERE job_id = $4
|
||||
AND host_id = $5
|
||||
"#,
|
||||
)
|
||||
.bind(db_status)
|
||||
.bind(output)
|
||||
.bind(error_msg)
|
||||
.bind(row.job_id)
|
||||
.bind(row.host_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = $1::job_status,
|
||||
output = CASE WHEN $2 != '' THEN $2 ELSE output END
|
||||
WHERE job_id = $3
|
||||
AND host_id = $4
|
||||
"#,
|
||||
)
|
||||
.bind(db_status)
|
||||
.bind(output)
|
||||
.bind(row.job_id)
|
||||
.bind(row.host_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
};
|
||||
|
||||
if let Err(e) = update_result {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: DB update failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Also update the parent patch_jobs status when the host-level job reaches
|
||||
// a terminal state: running → if all hosts terminal then update parent.
|
||||
if is_terminal {
|
||||
update_parent_job_status(pool, row.job_id).await;
|
||||
}
|
||||
|
||||
// Fire pg_notify so browser WS handlers forward the host-level event.
|
||||
let payload = NotifyPayload {
|
||||
event_type: "host".to_string(),
|
||||
job_id: row.job_id.to_string(),
|
||||
host_id: row.host_id.to_string(),
|
||||
status: db_status.to_string(),
|
||||
output: event.output.clone(),
|
||||
error_message: event.error.clone(),
|
||||
agent_job_id: row.agent_job_id.clone(),
|
||||
succeeded_count: None,
|
||||
failed_count: None,
|
||||
host_count: None,
|
||||
};
|
||||
|
||||
let payload_json = match serde_json::to_string(&payload) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "WS relay: failed to serialize notify payload");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_json)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: pg_notify failed"
|
||||
);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %db_status,
|
||||
"WS relay: pg_notify sent"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Parent job status rollup ──────────────────────────────────────────────────
|
||||
|
||||
/// After a host-level job reaches a terminal state, check whether ALL hosts for
|
||||
/// that job are now terminal and update the parent `patch_jobs` row accordingly.
|
||||
///
|
||||
/// If the parent job transitions to a terminal status, also fires a `job_update`
|
||||
/// pg_notify with `event_type: "job"` so the frontend can update the job row.
|
||||
async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) {
|
||||
// Count hosts that are still in a non-terminal state.
|
||||
let pending: i64 = match sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
AND status NOT IN (
|
||||
'succeeded'::job_status,
|
||||
'failed'::job_status,
|
||||
'cancelled'::job_status
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if pending > 0 {
|
||||
return; // still hosts running — parent stays running
|
||||
}
|
||||
|
||||
// All hosts terminal — determine final parent status and counts.
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct RollupCounts {
|
||||
total: i64,
|
||||
succeeded: i64,
|
||||
failed: i64,
|
||||
}
|
||||
|
||||
let counts: RollupCounts = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE status = 'succeeded') AS succeeded,
|
||||
COUNT(*) FILTER (WHERE status = 'failed') AS failed
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: rollup query failed");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
let final_status = if counts.failed > 0 {
|
||||
"failed"
|
||||
} else {
|
||||
"succeeded"
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(final_status)
|
||||
.bind(job_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"update_parent_job_status: UPDATE failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"Parent job status updated"
|
||||
);
|
||||
|
||||
// Fire job-level pg_notify so the frontend can update the job row.
|
||||
let payload = NotifyPayload {
|
||||
event_type: "job".to_string(),
|
||||
job_id: job_id.to_string(),
|
||||
host_id: String::new(), // no specific host for job-level events
|
||||
status: final_status.to_string(),
|
||||
output: None,
|
||||
error_message: None,
|
||||
agent_job_id: String::new(),
|
||||
succeeded_count: Some(counts.succeeded),
|
||||
failed_count: Some(counts.failed),
|
||||
host_count: Some(counts.total),
|
||||
};
|
||||
|
||||
let payload_json = match serde_json::to_string(&payload) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: failed to serialize job-level notify payload");
|
||||
return;
|
||||
},
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_json)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
%job_id,
|
||||
"update_parent_job_status: job-level pg_notify failed"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"Job-level pg_notify sent"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user