M5: Patch Deployment & Job Management
Backend:
- migrations/003_jobs_scheduling.sql: retry_next_at/last_error columns, pg_notify trigger for immediate job dispatch, retry index
- pm-agent-client: ApplyPatchesRequest/Response, AgentJobStatus, RollbackResponse types; apply_patches/job_status/rollback_job client methods + generic POST helper
- pm-core/models: JobStatus, JobKind, PatchJob, PatchJobHost, CreateJobRequest, PatchJobSummary
- pm-web/routes/jobs.rs: POST/GET /api/v1/jobs, GET /jobs/:id, POST /jobs/:id/cancel, POST /jobs/:id/rollback
- pm-worker/job_executor.rs: NOTIFY listener, periodic scanner, execute_host_job, poll_running_jobs, handle_host_failure (3-retry exponential backoff 1m/5m/30m), sync_job_status, retry_pending_jobs
- pm-worker/main.rs: spawn job_executor

Frontend:
- types/index.ts: PatchInfo, PatchJobHost, PatchJob, PatchJobSummary, CreateJobRequest interfaces
- api/client.ts: jobsApi (list/get/create/cancel/rollback), patchesApi (getHostPatches)
- pages/PatchDeploymentPage.tsx: 3-step MUI Stepper (host select → configure → result)
- pages/JobsPage.tsx: job list table, expandable per-host detail, cancel/rollback actions with confirm dialog, load-more pagination
- App.tsx: /jobs and /deployment routes wired to real pages

cargo check: 0 errors | vite build: 0 errors
This commit is contained in:
@ -1,10 +1,284 @@
|
||||
//! Agent HTTP client stub.
|
||||
//! Full mTLS Rustls-based implementation arrives in M4.
|
||||
//! mTLS HTTP client for communicating with Linux Patch API agents.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use pm_agent_client::client::AgentClient;
|
||||
//!
|
||||
//! # async fn example() -> Result<(), pm_agent_client::error::AgentClientError> {
|
||||
//! let client = AgentClient::new(
|
||||
//! "192.168.1.10",
|
||||
//! 12443,
|
||||
//! include_bytes!("../certs/client.crt"),
|
||||
//! include_bytes!("../certs/client.key"),
|
||||
//! include_bytes!("../certs/ca.crt"),
|
||||
//! )?;
|
||||
//!
|
||||
//! let health = client.health().await?;
|
||||
//! println!("Agent status: {}", health.status);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use thiserror::Error;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AgentClientError {
|
||||
#[error("Not yet implemented")]
|
||||
NotImplemented,
|
||||
use reqwest::{
|
||||
tls::Version,
|
||||
Certificate, ClientBuilder, Identity,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
use crate::{
|
||||
error::AgentClientError,
|
||||
types::{
|
||||
AgentEnvelope, HealthData, PackagesData, PatchesData, SystemInfoData,
|
||||
ApplyPatchesRequest, ApplyPatchesResponse, AgentJobStatus, RollbackResponse,
|
||||
},
|
||||
};
|
||||
|
||||
/// TCP port on which a Linux Patch API agent listens unless configured otherwise.
pub const DEFAULT_AGENT_PORT: u16 = 12443;

/// Upper bound on the duration of any single agent API call (30 seconds).
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
// ============================================================
|
||||
// AgentClient
|
||||
// ============================================================
|
||||
|
||||
/// Async HTTP client that speaks mTLS to a single Linux Patch API agent.
|
||||
///
|
||||
/// Construct once via [`AgentClient::new`] and reuse across calls;
|
||||
/// the underlying [`reqwest::Client`] maintains a connection pool.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentClient {
|
||||
/// Underlying HTTP client (configured for mTLS + TLS 1.3).
|
||||
inner: reqwest::Client,
|
||||
/// Base URL of the agent, e.g. `https://10.0.0.5:12443/api/v1`.
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl AgentClient {
|
||||
/// Create a new [`AgentClient`] configured for mTLS.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `host_ip` – IP address (or hostname) of the agent.
|
||||
/// * `port` – TCP port the agent listens on (default [`DEFAULT_AGENT_PORT`]).
|
||||
/// * `client_cert_pem` – PEM-encoded client certificate presented during the TLS handshake.
|
||||
/// * `client_key_pem` – PEM-encoded private key matching `client_cert_pem`.
|
||||
/// * `ca_cert_pem` – PEM-encoded CA certificate used to verify the agent's server cert.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`AgentClientError::Tls`] when certificate parsing fails, or
|
||||
/// [`AgentClientError::Request`] when `reqwest` client construction fails.
|
||||
pub fn new(
|
||||
host_ip: &str,
|
||||
port: u16,
|
||||
client_cert_pem: &[u8],
|
||||
client_key_pem: &[u8],
|
||||
ca_cert_pem: &[u8],
|
||||
) -> Result<Self, AgentClientError> {
|
||||
// Build client identity: reqwest expects cert + key concatenated as PEM.
|
||||
let mut identity_pem = Vec::with_capacity(client_cert_pem.len() + client_key_pem.len());
|
||||
identity_pem.extend_from_slice(client_cert_pem);
|
||||
identity_pem.extend_from_slice(client_key_pem);
|
||||
|
||||
let identity = Identity::from_pem(&identity_pem)
|
||||
.map_err(|e| AgentClientError::Tls(format!("invalid client identity PEM: {e}")))?;
|
||||
|
||||
// Parse the CA certificate used to verify the agent's server certificate.
|
||||
let ca_cert = Certificate::from_pem(ca_cert_pem)
|
||||
.map_err(|e| AgentClientError::Tls(format!("invalid CA certificate PEM: {e}")))?;
|
||||
|
||||
// Build the reqwest client:
|
||||
// - force rustls TLS backend
|
||||
// - disable built-in OS/system trust roots (only trust our internal CA)
|
||||
// - enforce TLS 1.3 minimum
|
||||
// - attach client identity (mTLS)
|
||||
// - add our CA as a trusted root
|
||||
// - apply a global request timeout
|
||||
let inner = ClientBuilder::new()
|
||||
.use_rustls_tls()
|
||||
.tls_built_in_root_certs(false)
|
||||
.min_tls_version(Version::TLS_1_3)
|
||||
.identity(identity)
|
||||
.add_root_certificate(ca_cert)
|
||||
.timeout(REQUEST_TIMEOUT)
|
||||
.build()
|
||||
.map_err(|e| AgentClientError::Request(e))?;
|
||||
|
||||
let base_url = format!("https://{}:{}/api/v1", host_ip, port);
|
||||
tracing::debug!(base_url = %base_url, "AgentClient created");
|
||||
|
||||
Ok(Self { inner, base_url })
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Public API methods
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// `GET /api/v1/health` — check agent liveness and retrieve uptime.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn health(&self) -> Result<HealthData, AgentClientError> {
|
||||
self.get("health", &[]).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/system/info` — retrieve host system information.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn system_info(&self) -> Result<SystemInfoData, AgentClientError> {
|
||||
self.get("system/info", &[]).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/packages?status=upgradable` — list packages with available upgrades.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn packages_upgradable(&self) -> Result<PackagesData, AgentClientError> {
|
||||
self.get("packages", &[("status", "upgradable")]).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/patches` — list available patches with severity and CVE data.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url))]
|
||||
pub async fn patches(&self) -> Result<PatchesData, AgentClientError> {
|
||||
self.get("patches", &[]).await
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Private helpers
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// Execute a GET request against `{base_url}/{path}` with optional query
|
||||
/// parameters, deserialize the [`AgentEnvelope`], and extract the `data`
|
||||
/// field — or propagate an [`AgentClientError::ApiError`].
|
||||
async fn get<T>(
|
||||
&self,
|
||||
path: &str,
|
||||
query: &[(&str, &str)],
|
||||
) -> Result<T, AgentClientError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let url = format!("{}/{}", self.base_url, path);
|
||||
debug!(url = %url, ?query, "Sending GET request to agent");
|
||||
|
||||
let mut request = self.inner.get(&url);
|
||||
if !query.is_empty() {
|
||||
request = request.query(query);
|
||||
}
|
||||
|
||||
let response = request.send().await?;
|
||||
let status = response.status();
|
||||
debug!(url = %url, status = %status, "Received response from agent");
|
||||
|
||||
// Capture body text so we can attempt to deserialise the error envelope
|
||||
// even for non-2xx responses.
|
||||
let body = response.text().await?;
|
||||
|
||||
// Attempt to parse the standard agent envelope regardless of HTTP status.
|
||||
// The agent may embed a structured error body on 4xx/5xx responses.
|
||||
let envelope: AgentEnvelope<T> = serde_json::from_str(&body)?;
|
||||
|
||||
if !status.is_success() || !envelope.success {
|
||||
// Prefer the structured error from the envelope when present.
|
||||
if let Some(err) = envelope.error {
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: err.code,
|
||||
message: err.message,
|
||||
});
|
||||
}
|
||||
// Fallback: use the HTTP status as the error indicator.
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: status.as_str().to_string(),
|
||||
message: format!(
|
||||
"Agent returned HTTP {} for {}",
|
||||
status.as_u16(),
|
||||
url
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
// On success the `data` field must be present.
|
||||
envelope.data.ok_or_else(|| AgentClientError::ApiError {
|
||||
code: "MISSING_DATA".to_string(),
|
||||
message: "Agent response success=true but data field is absent".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Patch apply / job management methods
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// `POST /api/v1/patches/apply` — trigger patch application on the agent.
|
||||
#[instrument(skip(self, req), fields(base_url = %self.base_url))]
|
||||
pub async fn apply_patches(
|
||||
&self,
|
||||
req: &ApplyPatchesRequest,
|
||||
) -> Result<ApplyPatchesResponse, AgentClientError> {
|
||||
self.post("patches/apply", req).await
|
||||
}
|
||||
|
||||
/// `GET /api/v1/jobs/{id}` — poll an async agent job for status.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))]
|
||||
pub async fn job_status(
|
||||
&self,
|
||||
job_id: &str,
|
||||
) -> Result<AgentJobStatus, AgentClientError> {
|
||||
self.get(&format!("jobs/{}", job_id), &[]).await
|
||||
}
|
||||
|
||||
/// `POST /api/v1/jobs/{id}/rollback` — trigger rollback on the agent.
|
||||
#[instrument(skip(self), fields(base_url = %self.base_url, job_id = %job_id))]
|
||||
pub async fn rollback_job(
|
||||
&self,
|
||||
job_id: &str,
|
||||
) -> Result<RollbackResponse, AgentClientError> {
|
||||
let empty: serde_json::Value = serde_json::json!({});
|
||||
self.post(&format!("jobs/{}/rollback", job_id), &empty).await
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Private POST helper
|
||||
// --------------------------------------------------------
|
||||
|
||||
/// Execute a POST request against `{base_url}/{path}`, serialize `body` as
|
||||
/// JSON, deserialize the [`AgentEnvelope`], and extract the `data` field —
|
||||
/// or propagate an [`AgentClientError::ApiError`].
|
||||
async fn post<Req, Resp>(
|
||||
&self,
|
||||
path: &str,
|
||||
body: &Req,
|
||||
) -> Result<Resp, AgentClientError>
|
||||
where
|
||||
Req: Serialize,
|
||||
Resp: DeserializeOwned,
|
||||
{
|
||||
let url = format!("{}/{}", self.base_url, path);
|
||||
debug!(url = %url, "Sending POST request to agent");
|
||||
|
||||
let response = self.inner.post(&url).json(body).send().await?;
|
||||
let status = response.status();
|
||||
debug!(url = %url, status = %status, "Received POST response from agent");
|
||||
|
||||
let body_text = response.text().await?;
|
||||
let envelope: AgentEnvelope<Resp> = serde_json::from_str(&body_text)?;
|
||||
|
||||
if !status.is_success() || !envelope.success {
|
||||
if let Some(err) = envelope.error {
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: err.code,
|
||||
message: err.message,
|
||||
});
|
||||
}
|
||||
return Err(AgentClientError::ApiError {
|
||||
code: status.as_str().to_string(),
|
||||
message: format!("Agent returned HTTP {} for {}", status.as_u16(), url),
|
||||
});
|
||||
}
|
||||
|
||||
envelope.data.ok_or_else(|| AgentClientError::ApiError {
|
||||
code: "MISSING_DATA".to_string(),
|
||||
message: "Agent response success=true but data field is absent".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
49
crates/pm-agent-client/src/error.rs
Normal file
49
crates/pm-agent-client/src/error.rs
Normal file
@ -0,0 +1,49 @@
|
||||
//! Error types for the pm-agent-client crate.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Top-level error type returned by [`crate::client::AgentClient`] methods.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AgentClientError {
|
||||
/// TLS configuration or handshake failure.
|
||||
#[error("TLS error: {0}")]
|
||||
Tls(String),
|
||||
|
||||
/// Unable to establish a TCP/TLS connection to the agent.
|
||||
#[error("Connection error: {0}")]
|
||||
Connect(#[source] reqwest::Error),
|
||||
|
||||
/// An HTTP request or response transport error (not a timeout).
|
||||
#[error("Request error: {0}")]
|
||||
Request(#[source] reqwest::Error),
|
||||
|
||||
/// The request did not complete within the configured timeout.
|
||||
#[error("Request timed out")]
|
||||
Timeout,
|
||||
|
||||
/// The agent returned a non-2xx HTTP status or `success: false` in the
|
||||
/// response envelope.
|
||||
#[error("Agent API error [{code}]: {message}")]
|
||||
ApiError {
|
||||
/// Machine-readable error code supplied by the agent (e.g. `"NOT_FOUND"`).
|
||||
code: String,
|
||||
/// Human-readable description returned by the agent.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// JSON deserialization of the agent response failed.
|
||||
#[error("Failed to deserialise agent response: {0}")]
|
||||
Deserialize(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for AgentClientError {
|
||||
fn from(err: reqwest::Error) -> Self {
|
||||
if err.is_timeout() {
|
||||
AgentClientError::Timeout
|
||||
} else if err.is_connect() {
|
||||
AgentClientError::Connect(err)
|
||||
} else {
|
||||
AgentClientError::Request(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,46 @@
|
||||
//! pm-agent-client — mTLS HTTP client for Linux Patch API agent communication.
|
||||
//! `pm-agent-client` — mTLS HTTP client for Linux Patch API agent communication.
|
||||
//!
|
||||
//! M1: Stub. Full implementation in M4.
|
||||
//! This crate provides [`client::AgentClient`], an async HTTP client that
|
||||
//! establishes mutual-TLS connections (TLS 1.3) to `linux_patch_api` agents
|
||||
//! running on managed hosts.
|
||||
//!
|
||||
//! # Quick start
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use pm_agent_client::AgentClient;
|
||||
//!
|
||||
//! # async fn run() -> Result<(), pm_agent_client::AgentClientError> {
|
||||
//! let client = AgentClient::new(
|
||||
//! "10.0.1.5",
|
||||
//! 12443,
|
||||
//! include_bytes!("../certs/client.crt"),
|
||||
//! include_bytes!("../certs/client.key"),
|
||||
//! include_bytes!("../certs/ca.crt"),
|
||||
//! )?;
|
||||
//!
|
||||
//! let health = client.health().await?;
|
||||
//! println!("Agent {}: {}", health.status, health.version);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
pub mod client;
|
||||
pub mod error;
|
||||
pub mod types;
|
||||
|
||||
// ── Convenience re-exports ──────────────────────────────────────────────────
|
||||
|
||||
/// Primary client — re-exported from [`client::AgentClient`].
|
||||
pub use client::{AgentClient, DEFAULT_AGENT_PORT};
|
||||
|
||||
/// Error type — re-exported from [`error::AgentClientError`].
|
||||
pub use error::AgentClientError;
|
||||
|
||||
/// Response envelope and all data types.
|
||||
pub use types::{
|
||||
AgentEnvelope, AgentErrorBody,
|
||||
HealthData,
|
||||
Package, PackagesData,
|
||||
Patch, PatchesData,
|
||||
SystemInfoData,
|
||||
};
|
||||
|
||||
205
crates/pm-agent-client/src/types.rs
Normal file
205
crates/pm-agent-client/src/types.rs
Normal file
@ -0,0 +1,205 @@
|
||||
//! Response and request types for the Linux Patch API agent endpoints.
|
||||
//!
|
||||
//! All agent responses are wrapped in [`AgentEnvelope<T>`].
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
// ============================================================
|
||||
// Envelope & error
|
||||
// ============================================================
|
||||
|
||||
/// Generic response wrapper returned by every agent endpoint.
|
||||
///
|
||||
/// ```json
|
||||
/// { "success": true, "request_id": "…", "timestamp": "…", "data": {…}, "error": null }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentEnvelope<T> {
|
||||
/// `true` when the request succeeded; `false` on error.
|
||||
pub success: bool,
|
||||
/// Server-assigned request identifier (UUID v4).
|
||||
pub request_id: Uuid,
|
||||
/// Server timestamp for the response (ISO-8601 / RFC-3339).
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// Response payload — present when `success` is `true`.
|
||||
pub data: Option<T>,
|
||||
/// Error detail — present when `success` is `false`.
|
||||
pub error: Option<AgentErrorBody>,
|
||||
}
|
||||
|
||||
/// Structured error returned inside [`AgentEnvelope::error`].
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentErrorBody {
|
||||
/// Machine-readable error code (e.g. `"INTERNAL_ERROR"`).
|
||||
pub code: String,
|
||||
/// Human-readable description of what went wrong.
|
||||
pub message: String,
|
||||
/// Optional free-form extra detail from the agent.
|
||||
#[serde(default)]
|
||||
pub details: Option<serde_json::Value>,
|
||||
/// Whether the caller may safely retry the request.
|
||||
#[serde(default)]
|
||||
pub retryable: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/health
|
||||
// ============================================================
|
||||
|
||||
/// Payload returned by `GET /api/v1/health`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct HealthData {
|
||||
/// Agent status string, e.g. `"ok"` or `"degraded"`.
|
||||
pub status: String,
|
||||
/// Seconds elapsed since the agent process started.
|
||||
pub uptime_seconds: u64,
|
||||
/// Agent software version string.
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/system/info
|
||||
// ============================================================
|
||||
|
||||
/// Payload returned by `GET /api/v1/system/info`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct SystemInfoData {
|
||||
/// Hostname of the managed system.
|
||||
pub hostname: String,
|
||||
/// OS family / distribution name (e.g. `"Ubuntu"`).
|
||||
pub os: String,
|
||||
/// OS version string.
|
||||
pub os_version: String,
|
||||
/// Kernel version string.
|
||||
pub kernel: String,
|
||||
/// CPU architecture (e.g. `"x86_64"`).
|
||||
pub architecture: String,
|
||||
/// When the agent last checked for updates (`null` if never).
|
||||
pub last_update_check: Option<DateTime<Utc>>,
|
||||
/// When updates were last applied (`null` if never).
|
||||
pub last_update_apply: Option<DateTime<Utc>>,
|
||||
/// Whether the system has a pending reboot.
|
||||
pub pending_reboot: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/packages?status=upgradable
|
||||
// ============================================================
|
||||
|
||||
/// A single package entry.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Package {
|
||||
/// Package name.
|
||||
pub name: String,
|
||||
/// Installed version.
|
||||
pub version: String,
|
||||
/// Package status string (e.g. `"installed"`, `"upgradable"`).
|
||||
pub status: String,
|
||||
/// Whether a newer version is available.
|
||||
pub upgradable: bool,
|
||||
/// Latest available version (`null` if not upgradable).
|
||||
pub latest_version: Option<String>,
|
||||
/// Short package description.
|
||||
pub description: String,
|
||||
/// CVE identifiers associated with this package.
|
||||
#[serde(default)]
|
||||
pub cve_ids: Vec<String>,
|
||||
}
|
||||
|
||||
/// Payload returned by `GET /api/v1/packages`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct PackagesData {
|
||||
/// List of packages matching the query filters.
|
||||
pub packages: Vec<Package>,
|
||||
/// Total count of matching packages.
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/patches
|
||||
// ============================================================
|
||||
|
||||
/// A single available patch.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Patch {
|
||||
/// Package / patch name.
|
||||
pub name: String,
|
||||
/// Currently installed version.
|
||||
pub current_version: String,
|
||||
/// Version available after applying this patch.
|
||||
pub available_version: String,
|
||||
/// Severity level (e.g. `"critical"`, `"high"`, `"medium"`, `"low"`).
|
||||
pub severity: String,
|
||||
/// Human-readable description of the patch.
|
||||
pub description: String,
|
||||
/// CVE identifiers addressed by this patch.
|
||||
#[serde(default)]
|
||||
pub cve_ids: Vec<String>,
|
||||
/// Whether applying this patch requires a system reboot.
|
||||
pub requires_reboot: bool,
|
||||
}
|
||||
|
||||
/// Payload returned by `GET /api/v1/patches`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct PatchesData {
|
||||
/// List of available patches.
|
||||
pub patches: Vec<Patch>,
|
||||
/// Total patch count.
|
||||
pub total: u64,
|
||||
/// Number of patches classified as security updates.
|
||||
pub security_updates: u64,
|
||||
/// Whether any patch in the list requires a reboot.
|
||||
pub requires_reboot: bool,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/patches/apply
|
||||
// ============================================================
|
||||
|
||||
/// Request body for `POST /api/v1/patches/apply`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApplyPatchesRequest {
|
||||
/// Package names to apply. Empty = apply all available patches.
|
||||
pub packages: Vec<String>,
|
||||
/// If true, allow automatic reboot after patching if required.
|
||||
pub allow_reboot: bool,
|
||||
}
|
||||
|
||||
/// Response from `POST /api/v1/patches/apply`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ApplyPatchesResponse {
|
||||
/// Agent-assigned async job ID for status polling.
|
||||
pub job_id: String,
|
||||
/// Initial status: typically `"running"` or `"queued"`.
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// GET /api/v1/jobs/{id}
|
||||
// ============================================================
|
||||
|
||||
/// Status of an async agent job returned by `GET /api/v1/jobs/{id}`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentJobStatus {
|
||||
pub job_id: String,
|
||||
/// Current status: `"running"`, `"succeeded"`, `"failed"`, or `"cancelled"`.
|
||||
pub status: String,
|
||||
pub progress_percent: Option<u8>,
|
||||
pub output: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// POST /api/v1/jobs/{id}/rollback
|
||||
// ============================================================
|
||||
|
||||
/// Response from `POST /api/v1/jobs/{id}/rollback`.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RollbackResponse {
|
||||
pub job_id: String,
|
||||
pub status: String,
|
||||
}
|
||||
@ -192,3 +192,109 @@ pub struct RegisterDiscoveredRequest {
|
||||
pub display_name: Option<String>,
|
||||
pub group_ids: Option<Vec<Uuid>>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Patch Jobs
|
||||
// ============================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[sqlx(type_name = "job_status", rename_all = "lowercase")]
|
||||
pub enum JobStatus {
|
||||
Queued,
|
||||
Pending,
|
||||
Running,
|
||||
Succeeded,
|
||||
Failed,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for JobStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Queued => write!(f, "queued"),
|
||||
Self::Pending => write!(f, "pending"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Succeeded => write!(f, "succeeded"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
Self::Cancelled => write!(f, "cancelled"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[sqlx(type_name = "job_kind", rename_all = "snake_case")]
|
||||
pub enum JobKind {
|
||||
#[sqlx(rename = "patch_apply")]
|
||||
PatchApply,
|
||||
#[sqlx(rename = "patch_remove")]
|
||||
PatchRemove,
|
||||
Reboot,
|
||||
Rollback,
|
||||
}
|
||||
|
||||
/// Full `patch_jobs` row.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PatchJob {
|
||||
pub id: Uuid,
|
||||
pub kind: JobKind,
|
||||
pub status: JobStatus,
|
||||
pub created_by_user_id: Option<Uuid>,
|
||||
pub parent_job_id: Option<Uuid>,
|
||||
pub maintenance_window_id: Option<Uuid>,
|
||||
pub immediate: bool,
|
||||
pub patch_selection: serde_json::Value,
|
||||
pub notes: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Full `patch_job_hosts` row (includes columns added in migration 003).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PatchJobHost {
|
||||
pub id: Uuid,
|
||||
pub job_id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub status: JobStatus,
|
||||
pub agent_job_id: Option<String>,
|
||||
pub retry_count: i32,
|
||||
pub output: String,
|
||||
pub error_message: Option<String>,
|
||||
pub retry_next_at: Option<DateTime<Utc>>,
|
||||
pub last_error: Option<String>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Request payload for creating a patch job via `POST /api/v1/jobs`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateJobRequest {
|
||||
/// Host IDs to patch.
|
||||
pub host_ids: Vec<Uuid>,
|
||||
/// Package names to apply (empty = all available patches).
|
||||
pub packages: Vec<String>,
|
||||
/// If true: apply immediately. If false: queue for next maintenance window.
|
||||
pub immediate: bool,
|
||||
/// Optional maintenance window to bind to.
|
||||
pub maintenance_window_id: Option<Uuid>,
|
||||
/// Allow reboot if required by patches.
|
||||
pub allow_reboot: Option<bool>,
|
||||
/// Optional operator notes.
|
||||
pub notes: Option<String>,
|
||||
}
|
||||
|
||||
/// Summary row for job list view (aggregates per-host counts).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PatchJobSummary {
|
||||
pub id: Uuid,
|
||||
pub kind: JobKind,
|
||||
pub status: JobStatus,
|
||||
pub immediate: bool,
|
||||
pub host_count: i64,
|
||||
pub succeeded_count: i64,
|
||||
pub failed_count: i64,
|
||||
pub notes: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
@ -105,6 +105,10 @@ pub fn build_router(state: AppState) -> Router {
|
||||
.nest("/users", routes::users::router())
|
||||
// Discovery
|
||||
.nest("/discovery", routes::discovery::router())
|
||||
// Fleet status
|
||||
.nest("/status", routes::status::router())
|
||||
// Patch jobs
|
||||
.nest("/jobs", routes::jobs::router())
|
||||
// Apply auth middleware to all the above
|
||||
.route_layer(middleware::from_fn(move |req, next| {
|
||||
let auth_config = auth_config.clone();
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
//! GET /api/v1/hosts/{id}/groups — list groups for host
|
||||
//! POST /api/v1/hosts/{id}/groups — assign host to group
|
||||
//! DELETE /api/v1/hosts/{id}/groups/{group_id} — remove host from group
|
||||
//! POST /api/v1/hosts/{id}/refresh — queue on-demand refresh (operator+)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
@ -34,6 +35,7 @@ pub fn router() -> Router<AppState> {
|
||||
.route("/:id", get(get_host).delete(remove_host))
|
||||
.route("/:id/groups", get(list_host_groups).post(add_host_to_group))
|
||||
.route("/:id/groups/:group_id", delete(remove_host_from_group))
|
||||
.route("/:id/refresh", post(refresh_host))
|
||||
}
|
||||
|
||||
// ── Query params ─────────────────────────────────────────────────────────────
|
||||
@ -470,3 +472,56 @@ async fn resolve_fqdn(fqdn: &str) -> Result<String, String> {
|
||||
_ => Err(format!("Failed to resolve FQDN: {fqdn}")),
|
||||
}
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:id/refresh ───────────────────────────────────────────
|
||||
|
||||
/// Queue an on-demand health + patch refresh for a single host.
|
||||
///
|
||||
/// Sends a PostgreSQL NOTIFY on the `refresh_requested` channel; the
|
||||
/// pm-worker refresh listener picks this up and polls the host immediately.
|
||||
/// Requires Operator or Admin role (any authenticated user).
|
||||
async fn refresh_host(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<(StatusCode, Json<Value>), (StatusCode, Json<Value>)> {
|
||||
// Verify the host exists.
|
||||
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "refresh_host: db error checking host existence");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !exists {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({ "error": { "code": "not_found", "message": "Host not found" } })),
|
||||
));
|
||||
}
|
||||
|
||||
// NOTIFY the worker's refresh listener.
|
||||
sqlx::query("SELECT pg_notify('refresh_requested', $1)")
|
||||
.bind(id.to_string())
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "refresh_host: pg_notify failed");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Failed to queue refresh" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(%id, "On-demand refresh queued");
|
||||
|
||||
Ok((
|
||||
StatusCode::ACCEPTED,
|
||||
Json(json!({ "message": "Refresh queued" })),
|
||||
))
|
||||
}
|
||||
|
||||
609
crates/pm-web/src/routes/jobs.rs
Normal file
609
crates/pm-web/src/routes/jobs.rs
Normal file
@ -0,0 +1,609 @@
|
||||
//! Patch job management routes.
|
||||
//!
|
||||
//! POST /api/v1/jobs — create a new patch job (operator+)
|
||||
//! GET /api/v1/jobs — list jobs with pagination (RBAC scoped)
|
||||
//! GET /api/v1/jobs/{id} — get job detail + per-host status
|
||||
//! POST /api/v1/jobs/{id}/cancel — cancel a queued/pending job (admin or creator)
|
||||
//! POST /api/v1/jobs/{id}/rollback — create a rollback job (admin only)
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{CreateJobRequest, PatchJobSummary},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_jobs).post(create_job))
|
||||
.route("/:id", get(get_job))
|
||||
.route("/:id/cancel", post(cancel_job))
|
||||
.route("/:id/rollback", post(rollback_job))
|
||||
}
|
||||
|
||||
// ── Query params ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct JobListQuery {
|
||||
pub limit: Option<i64>,
|
||||
pub offset: Option<i64>,
|
||||
}
|
||||
|
||||
// ── Response types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct JobListResponse {
|
||||
jobs: Vec<PatchJobSummary>,
|
||||
total: i64,
|
||||
limit: i64,
|
||||
offset: i64,
|
||||
}
|
||||
|
||||
/// Per-host row included in `GET /api/v1/jobs/{id}` response.
|
||||
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
|
||||
struct JobHostRow {
|
||||
pub id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub display_name: String,
|
||||
pub status: String,
|
||||
pub agent_job_id: Option<String>,
|
||||
pub retry_count: i32,
|
||||
pub output: String,
|
||||
pub error_message: Option<String>,
|
||||
pub last_error: Option<String>,
|
||||
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
// ── Error helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── RBAC helper ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Returns `true` when the operator's groups contain at least one host that
|
||||
/// belongs to the given job. Admins always pass this check at the call site.
|
||||
async fn operator_can_access_job(
|
||||
pool: &sqlx::PgPool,
|
||||
user_id: Uuid,
|
||||
job_id: Uuid,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN host_groups hg ON hg.host_id = pjh.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh.job_id = $1
|
||||
AND ug.user_id = $2
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn create_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Json(req): Json<CreateJobRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
if req.host_ids.is_empty() {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"host_ids must not be empty",
|
||||
));
|
||||
}
|
||||
|
||||
// Encode package list as JSONB.
|
||||
let patch_selection = serde_json::to_value(&req.packages).unwrap_or(json!([]));
|
||||
let notes = req.notes.clone().unwrap_or_default();
|
||||
|
||||
// Insert the parent job row; the DB NOTIFY trigger fires automatically
|
||||
// when immediate = TRUE (see migration 003_jobs_scheduling.sql).
|
||||
let job_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, created_by_user_id, maintenance_window_id,
|
||||
immediate, patch_selection, notes)
|
||||
VALUES
|
||||
('patch_apply'::job_kind, 'queued'::job_status, $1, $2, $3, $4, $5)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.bind(req.maintenance_window_id)
|
||||
.bind(req.immediate)
|
||||
.bind(&patch_selection)
|
||||
.bind(¬es)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "create_job: insert patch_jobs failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
// Insert one patch_job_hosts row per requested host.
|
||||
for host_id in &req.host_ids {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
VALUES ($1, $2, 'queued'::job_status)
|
||||
ON CONFLICT (job_id, host_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e, %job_id, %host_id,
|
||||
"create_job: insert patch_job_hosts failed"
|
||||
);
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&job_id.to_string()),
|
||||
json!({
|
||||
"kind": "patch_apply",
|
||||
"immediate": req.immediate,
|
||||
"host_count": req.host_ids.len(),
|
||||
"packages": req.packages,
|
||||
"notes": notes,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
job_id = %job_id,
|
||||
host_count = req.host_ids.len(),
|
||||
immediate = req.immediate,
|
||||
user = %auth.username,
|
||||
"Patch job created"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "id": job_id, "message": "Job created" })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/jobs ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn list_jobs(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Query(q): Query<JobListQuery>,
|
||||
) -> Result<Json<JobListResponse>, (StatusCode, Json<Value>)> {
|
||||
let limit = q.limit.unwrap_or(50).min(200);
|
||||
let offset = q.offset.unwrap_or(0);
|
||||
|
||||
let jobs: Vec<PatchJobSummary> = if auth.role.is_admin() {
|
||||
// Admins see every job.
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.kind,
|
||||
pj.status,
|
||||
pj.immediate,
|
||||
pj.notes,
|
||||
pj.created_at,
|
||||
pj.started_at,
|
||||
pj.completed_at,
|
||||
COUNT(pjh.id) AS host_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
|
||||
FROM patch_jobs pj
|
||||
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
|
||||
GROUP BY pj.id
|
||||
ORDER BY pj.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
} else {
|
||||
// Operators: only jobs where at least one host is in their groups.
|
||||
sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pj.id,
|
||||
pj.kind,
|
||||
pj.status,
|
||||
pj.immediate,
|
||||
pj.notes,
|
||||
pj.created_at,
|
||||
pj.started_at,
|
||||
pj.completed_at,
|
||||
COUNT(pjh.id) AS host_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'succeeded'::job_status) AS succeeded_count,
|
||||
COUNT(pjh.id) FILTER (WHERE pjh.status = 'failed'::job_status) AS failed_count
|
||||
FROM patch_jobs pj
|
||||
LEFT JOIN patch_job_hosts pjh ON pjh.job_id = pj.id
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh2
|
||||
JOIN host_groups hg ON hg.host_id = pjh2.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh2.job_id = pj.id
|
||||
AND ug.user_id = $3
|
||||
)
|
||||
GROUP BY pj.id
|
||||
ORDER BY pj.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
"#,
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.bind(auth.user_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
}
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "list_jobs: query failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
// Total count for pagination metadata.
|
||||
let total: i64 = if auth.role.is_admin() {
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM patch_jobs")
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
} else {
|
||||
sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(DISTINCT pj.id)
|
||||
FROM patch_jobs pj
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN host_groups hg ON hg.host_id = pjh.host_id
|
||||
JOIN user_groups ug ON ug.group_id = hg.group_id
|
||||
WHERE pjh.job_id = pj.id
|
||||
AND ug.user_id = $1
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
};
|
||||
|
||||
Ok(Json(JobListResponse { jobs, total, limit, offset }))
|
||||
}
|
||||
// ── GET /api/v1/jobs/:id ─────────────────────────────────────────────────────
|
||||
|
||||
async fn get_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// RBAC: operators may only view jobs touching their group's hosts.
|
||||
if !auth.role.is_admin() {
|
||||
let allowed = operator_can_access_job(&state.db, auth.user_id, id)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
if !allowed {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Access denied",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the job header row as JSON.
|
||||
let job: Option<Value> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT row_to_json(j) FROM (
|
||||
SELECT id, kind, status, created_by_user_id, parent_job_id,
|
||||
maintenance_window_id, immediate, patch_selection, notes,
|
||||
created_at, started_at, completed_at
|
||||
FROM patch_jobs
|
||||
WHERE id = $1
|
||||
) j
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "get_job: failed to fetch job");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
let job = job.ok_or_else(|| {
|
||||
err(StatusCode::NOT_FOUND, "not_found", "Job not found")
|
||||
})?;
|
||||
|
||||
// Fetch per-host status rows joined to the host display name.
|
||||
let hosts: Vec<JobHostRow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
pjh.id,
|
||||
pjh.host_id,
|
||||
COALESCE(h.display_name, h.fqdn) AS display_name,
|
||||
pjh.status::text AS status,
|
||||
pjh.agent_job_id,
|
||||
pjh.retry_count,
|
||||
pjh.output,
|
||||
pjh.error_message,
|
||||
pjh.last_error,
|
||||
pjh.started_at,
|
||||
pjh.completed_at
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.job_id = $1
|
||||
ORDER BY h.display_name
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "get_job: failed to fetch host rows");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "job": job, "hosts": hosts })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs/:id/cancel ─────────────────────────────────────────────
|
||||
|
||||
async fn cancel_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Fetch the job to verify it exists and check ownership.
|
||||
let row: Option<(String, Option<Uuid>)> = sqlx::query_as(
|
||||
"SELECT status::text, created_by_user_id FROM patch_jobs WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: db fetch failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
let (status_str, creator_id) = row.ok_or_else(|| {
|
||||
err(StatusCode::NOT_FOUND, "not_found", "Job not found")
|
||||
})?;
|
||||
|
||||
// Only admin or the job creator may cancel.
|
||||
if !auth.role.is_admin() {
|
||||
let is_creator = creator_id.map_or(false, |cid| cid == auth.user_id);
|
||||
if !is_creator {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Only admin or the job creator may cancel this job",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Only queued or pending jobs can be cancelled.
|
||||
if status_str != "queued" && status_str != "pending" {
|
||||
return Err(err(
|
||||
StatusCode::CONFLICT,
|
||||
"invalid_state",
|
||||
format!(
|
||||
"Cannot cancel a job in '{}' state; only queued or pending jobs may be cancelled",
|
||||
status_str
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Cancel the parent job.
|
||||
sqlx::query(
|
||||
"UPDATE patch_jobs SET status = 'cancelled'::job_status WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: update patch_jobs failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
// Cancel all queued/pending host rows for this job.
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'cancelled'::job_status
|
||||
WHERE job_id = $1
|
||||
AND status IN ('queued'::job_status, 'pending'::job_status)
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "cancel_job: update patch_job_hosts failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobCancelled,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&id.to_string()),
|
||||
json!({ "previous_status": status_str }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(job_id = %id, user = %auth.username, "Patch job cancelled");
|
||||
Ok(Json(json!({ "message": "Job cancelled" })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/jobs/:id/rollback ────────────────────────────────────────────
|
||||
|
||||
async fn rollback_job(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Admin-only operation.
|
||||
if !auth.role.is_admin() {
|
||||
return Err(err(
|
||||
StatusCode::FORBIDDEN,
|
||||
"forbidden",
|
||||
"Admin role required to create rollback jobs",
|
||||
));
|
||||
}
|
||||
|
||||
// Verify the original job exists.
|
||||
let original_exists: bool =
|
||||
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM patch_jobs WHERE id = $1)")
|
||||
.bind(id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "rollback_job: existence check failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
if !original_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Job not found"));
|
||||
}
|
||||
|
||||
// Gather the host IDs from the original job.
|
||||
let host_ids: Vec<Uuid> =
|
||||
sqlx::query_scalar("SELECT host_id FROM patch_job_hosts WHERE job_id = $1")
|
||||
.bind(id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %id, "rollback_job: host fetch failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
if host_ids.is_empty() {
|
||||
return Err(err(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"no_hosts",
|
||||
"Original job has no host entries to roll back",
|
||||
));
|
||||
}
|
||||
|
||||
// Create the rollback job row (immediate = true so the worker picks it up
|
||||
// right away and the NOTIFY trigger fires).
|
||||
let rollback_job_id: Uuid = sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO patch_jobs
|
||||
(kind, status, created_by_user_id, parent_job_id, immediate,
|
||||
patch_selection, notes)
|
||||
VALUES
|
||||
('rollback'::job_kind, 'queued'::job_status, $1, $2, TRUE,
|
||||
'[]'::jsonb, $3)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(auth.user_id)
|
||||
.bind(id)
|
||||
.bind(format!("Rollback of job {}", id))
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, parent_job_id = %id, "rollback_job: insert failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
// Replicate host list into the rollback job.
|
||||
for host_id in &host_ids {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO patch_job_hosts (job_id, host_id, status)
|
||||
VALUES ($1, $2, 'queued'::job_status)
|
||||
ON CONFLICT (job_id, host_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(rollback_job_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e, %rollback_job_id, %host_id,
|
||||
"rollback_job: insert patch_job_hosts failed"
|
||||
);
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::PatchJobRollback,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("job"),
|
||||
Some(&rollback_job_id.to_string()),
|
||||
json!({
|
||||
"original_job_id": id,
|
||||
"rollback_job_id": rollback_job_id,
|
||||
"host_count": host_ids.len(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
rollback_job_id = %rollback_job_id,
|
||||
original_job_id = %id,
|
||||
user = %auth.username,
|
||||
"Rollback job created"
|
||||
);
|
||||
|
||||
Ok(Json(json!({
|
||||
"id": rollback_job_id,
|
||||
"parent_job_id": id,
|
||||
"message": "Rollback job created"
|
||||
})))
|
||||
}
|
||||
@ -3,4 +3,6 @@ pub mod auth;
|
||||
pub mod discovery;
|
||||
pub mod groups;
|
||||
pub mod hosts;
|
||||
pub mod jobs;
|
||||
pub mod status;
|
||||
pub mod users;
|
||||
|
||||
151
crates/pm-web/src/routes/status.rs
Normal file
151
crates/pm-web/src/routes/status.rs
Normal file
@ -0,0 +1,151 @@
|
||||
//! Fleet status routes.
|
||||
//!
|
||||
//! GET /api/v1/status/fleet — aggregate health and patch summary across all hosts.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new().route("/fleet", get(fleet_status))
|
||||
}
|
||||
|
||||
// ── Response type ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FleetStatus {
|
||||
pub total_hosts: i64,
|
||||
pub healthy: i64,
|
||||
pub degraded: i64,
|
||||
pub unreachable: i64,
|
||||
pub pending: i64,
|
||||
pub total_pending_patches: i64,
|
||||
pub hosts_requiring_reboot: i64,
|
||||
pub compliance_pct: f64,
|
||||
}
|
||||
|
||||
// ── GET /api/v1/status/fleet ──────────────────────────────────────────────────
|
||||
|
||||
pub async fn fleet_status(
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<FleetStatus>, (StatusCode, Json<Value>)> {
|
||||
// ── 1. Host health aggregates ─────────────────────────────────────────
|
||||
let health_row: (i64, i64, i64, i64, i64) = sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) AS total_hosts,
|
||||
COUNT(*) FILTER (WHERE health_status = 'healthy') AS healthy,
|
||||
COUNT(*) FILTER (WHERE health_status = 'degraded') AS degraded,
|
||||
COUNT(*) FILTER (WHERE health_status = 'unreachable') AS unreachable,
|
||||
COUNT(*) FILTER (WHERE health_status = 'pending') AS pending
|
||||
FROM hosts
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query host health aggregates");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (total_hosts, healthy, degraded, unreachable, pending) = health_row;
|
||||
|
||||
// ── 2. Total pending patches across fleet (latest row per host) ───────
|
||||
let total_pending_patches: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COALESCE(SUM(patch_count), 0)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) patch_count
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query total pending patches");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// ── 3. Hosts requiring a reboot (latest patch row per host) ───────────
|
||||
let hosts_requiring_reboot: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) available_patches
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
WHERE available_patches @> '[{"requires_reboot": true}]'
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query reboot-required hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// ── 4. Compliance: hosts with zero pending patches / total hosts ───────
|
||||
// Hosts that have been polled and have patch_count == 0 are considered
|
||||
// compliant. Hosts with no patch data at all are excluded from the
|
||||
// compliance calculation.
|
||||
let compliant_hosts: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM (
|
||||
SELECT DISTINCT ON (host_id) patch_count
|
||||
FROM host_patch_data
|
||||
ORDER BY host_id, polled_at DESC
|
||||
) latest
|
||||
WHERE patch_count = 0
|
||||
"#,
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "fleet_status: failed to query compliant hosts");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": { "code": "internal_error", "message": "Database error" } })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let compliance_pct = if total_hosts == 0 {
|
||||
100.0_f64
|
||||
} else {
|
||||
(compliant_hosts as f64 / total_hosts as f64) * 100.0
|
||||
};
|
||||
|
||||
// Round to one decimal place.
|
||||
let compliance_pct = (compliance_pct * 10.0).round() / 10.0;
|
||||
|
||||
Ok(Json(FleetStatus {
|
||||
total_hosts,
|
||||
healthy,
|
||||
degraded,
|
||||
unreachable,
|
||||
pending,
|
||||
total_pending_patches,
|
||||
hosts_requiring_reboot,
|
||||
compliance_pct,
|
||||
}))
|
||||
}
|
||||
@ -11,7 +11,8 @@ path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
pm-core = { path = "../pm-core" }
|
||||
tokio = { workspace = true }
|
||||
pm-agent-client = { path = "../pm-agent-client" }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
sqlx = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
45
crates/pm-worker/src/agent_loader.rs
Normal file
45
crates/pm-worker/src/agent_loader.rs
Normal file
@ -0,0 +1,45 @@
|
||||
//! Helper for loading mTLS certificate/key material from disk.
|
||||
//!
|
||||
//! Reads PEM files referenced in [`SecurityConfig`] and returns the raw bytes
|
||||
//! needed by [`pm_agent_client::AgentClient`].
|
||||
|
||||
use pm_core::config::SecurityConfig;
|
||||
|
||||
/// File contents needed for mTLS towards an agent: the client certificate,
/// its private key, and the CA certificate, each held as the raw bytes read
/// from disk.
pub struct AgentCerts {
    /// Bytes of the client certificate file.
    pub client_cert: Vec<u8>,
    /// Bytes of the client private-key file.
    pub client_key: Vec<u8>,
    /// Bytes of the CA certificate file.
    pub ca_cert: Vec<u8>,
}
|
||||
|
||||
/// Load agent mTLS certificates from the paths specified in [`SecurityConfig`].
|
||||
///
|
||||
/// Returns an error if any file cannot be read. The caller should handle
|
||||
/// the error gracefully (log and skip the poll cycle) rather than crashing.
|
||||
pub fn load_agent_certs(security: &SecurityConfig) -> anyhow::Result<AgentCerts> {
|
||||
let client_cert = std::fs::read(&security.agent_client_cert_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read agent client cert '{}': {}",
|
||||
security.agent_client_cert_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let client_key = std::fs::read(&security.agent_client_key_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read agent client key '{}': {}",
|
||||
security.agent_client_key_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let ca_cert = std::fs::read(&security.ca_cert_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read CA cert '{}': {}",
|
||||
security.ca_cert_path,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(AgentCerts { client_cert, client_key, ca_cert })
|
||||
}
|
||||
202
crates/pm-worker/src/health_poller.rs
Normal file
202
crates/pm-worker/src/health_poller.rs
Normal file
@ -0,0 +1,202 @@
|
||||
//! Periodic health poller for all registered hosts.
|
||||
//!
|
||||
//! Polls every host via the agent `/health` endpoint on each tick of
|
||||
//! `health_poll_interval_secs`, with bounded concurrency controlled by a
|
||||
//! [`tokio::sync::Semaphore`].
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
use pm_core::{
|
||||
config::AppConfig,
|
||||
models::HostHealthStatus,
|
||||
};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{
|
||||
sync::Semaphore,
|
||||
time,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host projection fetched for each poll cycle.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the health poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all registered hosts are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_health_data` and the `hosts` table is updated.
|
||||
pub async fn run_health_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.health_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(
|
||||
interval_secs,
|
||||
"Health poller started"
|
||||
);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
// Load certs on each cycle so cert rotation is picked up automatically.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health poller: failed to load agent certs — skipping cycle");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
// Fetch all hosts.
|
||||
let hosts: Vec<HostRow> = match sqlx::query_as(
|
||||
"SELECT id, ip_address::text AS ip_address, agent_port FROM hosts ORDER BY id",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Health poller: failed to fetch hosts");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!("Health poller: no hosts registered, skipping cycle");
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = hosts.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for host in hosts {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
poll_host_health(pool, host, &cert, &key, &ca).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Collect results and tally counts.
|
||||
let mut healthy = 0usize;
|
||||
let mut degraded = 0usize;
|
||||
let mut unreachable = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(HostHealthStatus::Healthy) => healthy += 1,
|
||||
Ok(HostHealthStatus::Degraded) => degraded += 1,
|
||||
Ok(HostHealthStatus::Unreachable) => unreachable += 1,
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::error!(error = %e, "Health poller task panicked"),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
total,
|
||||
healthy,
|
||||
degraded,
|
||||
unreachable,
|
||||
"Health poll cycle complete"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll a single host, persist the result, and return the determined status.
|
||||
async fn poll_host_health(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> HostHealthStatus {
|
||||
// Determine status and optional health payload.
|
||||
let (status, payload) = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Health poller: failed to build AgentClient"
|
||||
);
|
||||
(HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
Ok(client) => match client.health().await {
|
||||
Ok(data) => {
|
||||
let payload = serde_json::to_value(&data).unwrap_or_default();
|
||||
(HostHealthStatus::Healthy, payload)
|
||||
}
|
||||
Err(AgentClientError::Timeout) => {
|
||||
tracing::warn!(host_id = %host.id, "Health poller: agent timed out");
|
||||
(HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
Err(AgentClientError::Connect(_)) => {
|
||||
tracing::warn!(host_id = %host.id, "Health poller: agent connection refused");
|
||||
(HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Health poller: agent error");
|
||||
(HostHealthStatus::Degraded, serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Insert into host_health_data.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_data (host_id, status, payload)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&status)
|
||||
.bind(&payload)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to insert health data");
|
||||
}
|
||||
|
||||
// Update hosts table.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE hosts
|
||||
SET health_status = $2, last_health_at = NOW()
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&status)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Health poller: failed to update host status");
|
||||
}
|
||||
|
||||
status
|
||||
}
|
||||
826
crates/pm-worker/src/job_executor.rs
Normal file
826
crates/pm-worker/src/job_executor.rs
Normal file
@ -0,0 +1,826 @@
|
||||
//! Job execution engine.
|
||||
//!
|
||||
//! Picks up patch jobs from the database, dispatches them to agents via mTLS,
|
||||
//! tracks progress, and handles retries with exponential back-off.
|
||||
//!
|
||||
//! Two concurrent loops run inside [`run_job_executor`]:
|
||||
//!
|
||||
//! 1. **NOTIFY listener** — listens on `job_enqueued`; triggers immediate
|
||||
//! dispatch for newly-enqueued jobs.
|
||||
//! 2. **Periodic scanner** — every 60 seconds:
|
||||
//! - picks up queued non-immediate jobs that were missed by NOTIFY,
|
||||
//! - polls running agent jobs for completion,
|
||||
//! - retries pending host jobs whose back-off window has elapsed.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use pm_agent_client::{AgentClient, types::ApplyPatchesRequest};
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{sync::Semaphore, time};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Internal DB row types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
#[allow(dead_code)]
|
||||
struct PatchJobHostQueued {
|
||||
id: Uuid,
|
||||
host_id: Uuid,
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct PatchJobHostRunning {
|
||||
id: Uuid,
|
||||
agent_job_id: String,
|
||||
job_id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct PatchJobHostPending {
|
||||
id: Uuid,
|
||||
host_id: Uuid,
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct JobPatchSelection {
|
||||
patch_selection: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct RetryRow {
|
||||
job_id: Uuid,
|
||||
retry_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct StatusCounts {
|
||||
running_count: i64,
|
||||
pending_count: i64,
|
||||
queued_count: i64,
|
||||
succeeded_count: i64,
|
||||
failed_count: i64,
|
||||
cancelled_count: i64,
|
||||
total_count: i64,
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public entry point
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Spawn the job executor and run it indefinitely.
|
||||
///
|
||||
/// Runs two independent tasks joined until both complete (they never do under
|
||||
/// normal operation):
|
||||
/// - NOTIFY-driven immediate dispatch (auto-reconnect on DB disconnect).
|
||||
/// - 60-second periodic scanner for queued / running / pending rows.
|
||||
pub async fn run_job_executor(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Job executor started");
|
||||
|
||||
let (pool_n, cfg_n) = (pool.clone(), config.clone());
|
||||
let (pool_s, cfg_s) = (pool.clone(), config.clone());
|
||||
|
||||
let notify_task = tokio::spawn(async move {
|
||||
run_notify_listener(pool_n, cfg_n).await;
|
||||
});
|
||||
let scan_task = tokio::spawn(async move {
|
||||
run_periodic_scanner(pool_s, cfg_s).await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(notify_task, scan_task);
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// NOTIFY listener (outer reconnect wrapper)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn run_notify_listener(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Job executor NOTIFY listener starting");
|
||||
loop {
|
||||
if let Err(e) = notify_listen_loop(&pool, &config).await {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Job executor NOTIFY listener disconnected, reconnecting in 5s"
|
||||
);
|
||||
time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inner NOTIFY loop — returns `Err` only on a fatal connection error so the
|
||||
/// outer loop can reconnect.
|
||||
async fn notify_listen_loop(
|
||||
pool: &PgPool,
|
||||
config: &Arc<AppConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut listener =
|
||||
sqlx::postgres::PgListener::connect(&config.database.url).await?;
|
||||
listener.listen("job_enqueued").await?;
|
||||
tracing::debug!("Job executor NOTIFY listener connected");
|
||||
|
||||
loop {
|
||||
let notification = listener.recv().await?;
|
||||
let payload = notification.payload().to_string();
|
||||
tracing::info!(payload, "job_enqueued notification received");
|
||||
|
||||
let job_id = match payload.parse::<Uuid>() {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
payload,
|
||||
error = %e,
|
||||
"Job executor: invalid UUID in job_enqueued payload"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
tokio::spawn(async move {
|
||||
process_job(p, c, job_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Periodic scanner
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn run_periodic_scanner(pool: PgPool, config: Arc<AppConfig>) {
|
||||
// First tick fires immediately — consume it to avoid a duplicate burst
|
||||
// right after NOTIFY already dispatched the same jobs.
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(60));
|
||||
ticker.tick().await;
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
tracing::debug!("Job executor periodic scan starting");
|
||||
|
||||
// 1. Pick up queued pjh rows that belong to non-cancelled jobs.
|
||||
scan_queued_jobs(pool.clone(), config.clone()).await;
|
||||
|
||||
// 2. Poll running pjh rows against the agent.
|
||||
poll_running_jobs(pool.clone(), config.clone()).await;
|
||||
|
||||
// 3. Retry pending pjh rows whose back-off window has elapsed.
|
||||
retry_pending_jobs(pool.clone(), config.clone()).await;
|
||||
|
||||
tracing::debug!("Job executor periodic scan complete");
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// scan_queued_jobs — feeds non-immediate jobs into process_job
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Discover distinct job-IDs that have queued host entries ready for dispatch
|
||||
/// and call [`process_job`] for each.
|
||||
async fn scan_queued_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
#[derive(FromRow)]
|
||||
struct JobIdRow {
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
let rows: Vec<JobIdRow> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT pjh.job_id
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs j ON j.id = pjh.job_id
|
||||
WHERE pjh.status = 'queued'
|
||||
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
|
||||
AND j.status != 'cancelled'
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "scan_queued_jobs: DB query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for row in rows {
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
tokio::spawn(async move {
|
||||
process_job(p, c, row.job_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// process_job
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Fetch all queued host entries for `job_id` and dispatch them concurrently,
|
||||
/// bounded by `config.worker.max_concurrent_agent_calls`.
|
||||
async fn process_job(pool: PgPool, config: Arc<AppConfig>, job_id: Uuid) {
|
||||
tracing::info!(%job_id, "process_job: dispatching queued hosts");
|
||||
|
||||
// Mark the parent job as running (idempotent guard).
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_jobs
|
||||
SET status = 'running',
|
||||
started_at = COALESCE(started_at, NOW())
|
||||
WHERE id = $1
|
||||
AND status NOT IN ('running','succeeded','failed','cancelled')
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%job_id, error = %e, "process_job: failed to mark job running");
|
||||
}
|
||||
|
||||
// Fetch all queued host entries for this job.
|
||||
let hosts: Vec<PatchJobHostQueued> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, job_id
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
AND status = 'queued'
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
tracing::error!(%job_id, error = %e, "process_job: failed to fetch queued hosts");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!(%job_id, "process_job: no queued hosts found (already dispatched)");
|
||||
return;
|
||||
}
|
||||
|
||||
let sem = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
for host in hosts {
|
||||
let permit = match sem.clone().acquire_owned().await {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!(%job_id, error = %e, "process_job: semaphore closed");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
let pjh_id = host.id;
|
||||
let host_id = host.host_id;
|
||||
|
||||
tokio::spawn(async move {
|
||||
execute_host_job(p, c, job_id, host_id, pjh_id).await;
|
||||
drop(permit);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// execute_host_job
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Connect to a single host agent, submit the patch job, and record the
|
||||
/// agent-assigned async job ID for later polling.
|
||||
async fn execute_host_job(
|
||||
pool: PgPool,
|
||||
config: Arc<AppConfig>,
|
||||
job_id: Uuid,
|
||||
host_id: Uuid,
|
||||
pjh_id: Uuid,
|
||||
) {
|
||||
tracing::info!(%job_id, %host_id, %pjh_id, "execute_host_job: starting");
|
||||
|
||||
// ── 1. Fetch host connection details ─────────────────────────────────────
|
||||
let host: HostRow = match sqlx::query_as(
|
||||
"SELECT ip_address::text AS ip_address, agent_port FROM hosts WHERE id = $1",
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(Some(h)) => h,
|
||||
Ok(None) => {
|
||||
tracing::error!(%host_id, "execute_host_job: host not found");
|
||||
handle_host_failure(
|
||||
pool,
|
||||
pjh_id,
|
||||
format!("Host {host_id} not found in database"),
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(%host_id, error = %e, "execute_host_job: DB error fetching host");
|
||||
handle_host_failure(pool, pjh_id, format!("DB error fetching host: {e}")).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// ── 2. Fetch the job's patch_selection ──────────────────────────────────
|
||||
let patch_sel: JobPatchSelection = match sqlx::query_as(
|
||||
"SELECT patch_selection FROM patch_jobs WHERE id = $1",
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => {
|
||||
tracing::error!(%job_id, "execute_host_job: parent job not found");
|
||||
handle_host_failure(pool, pjh_id, format!("Parent job {job_id} not found")).await;
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(%job_id, error = %e, "execute_host_job: DB error fetching job");
|
||||
handle_host_failure(pool, pjh_id, format!("DB error fetching job: {e}")).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let packages: Vec<String> =
|
||||
serde_json::from_value(patch_sel.patch_selection).unwrap_or_default();
|
||||
|
||||
// ── 3. Load mTLS certs ───────────────────────────────────────────────────
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(%host_id, error = %e, "execute_host_job: failed to load agent certs");
|
||||
handle_host_failure(pool, pjh_id, format!("Failed to load agent certs: {e}")).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// ── 4. Build AgentClient ─────────────────────────────────────────────────
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
&certs.client_cert,
|
||||
&certs.client_key,
|
||||
&certs.ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(%host_id, error = %e, "execute_host_job: failed to build AgentClient");
|
||||
handle_host_failure(pool, pjh_id, format!("Failed to build agent client: {e}")).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// ── 5. Mark pjh as running ───────────────────────────────────────────────
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'running',
|
||||
started_at = COALESCE(started_at, NOW())
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(pjh_id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%pjh_id, error = %e, "execute_host_job: failed to mark pjh running");
|
||||
}
|
||||
|
||||
// ── 6. Submit the patch job to the agent ─────────────────────────────────
|
||||
let req = ApplyPatchesRequest { packages, allow_reboot: true };
|
||||
|
||||
match client.apply_patches(&req).await {
|
||||
Ok(resp) => {
|
||||
tracing::info!(
|
||||
%pjh_id,
|
||||
agent_job_id = %resp.job_id,
|
||||
"execute_host_job: agent accepted job"
|
||||
);
|
||||
|
||||
// ── 7. Store agent_job_id; status stays 'running' (agent is async) ──
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE patch_job_hosts SET agent_job_id = $1 WHERE id = $2",
|
||||
)
|
||||
.bind(&resp.job_id)
|
||||
.bind(pjh_id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
%pjh_id,
|
||||
error = %e,
|
||||
"execute_host_job: failed to store agent_job_id"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(%pjh_id, error = %e, "execute_host_job: agent rejected job");
|
||||
handle_host_failure(pool, pjh_id, format!("Agent error: {e}")).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// poll_running_jobs
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Poll all running pjh rows that have an agent job ID and update their status.
|
||||
pub async fn poll_running_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let rows: Vec<PatchJobHostRunning> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT pjh.id,
|
||||
pjh.agent_job_id,
|
||||
pjh.job_id,
|
||||
h.ip_address::text AS ip_address,
|
||||
h.agent_port
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.status = 'running'
|
||||
AND pjh.agent_job_id IS NOT NULL
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "poll_running_jobs: DB query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for row in rows {
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
tokio::spawn(async move {
|
||||
poll_single_host(p, c, row).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll one running host entry and update its status from the agent response.
|
||||
async fn poll_single_host(
|
||||
pool: PgPool,
|
||||
config: Arc<AppConfig>,
|
||||
row: PatchJobHostRunning,
|
||||
) {
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
pjh_id = %row.id,
|
||||
error = %e,
|
||||
"poll_single_host: failed to load agent certs"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let client = match AgentClient::new(
|
||||
&row.ip_address,
|
||||
row.agent_port as u16,
|
||||
&certs.client_cert,
|
||||
&certs.client_key,
|
||||
&certs.ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
pjh_id = %row.id,
|
||||
error = %e,
|
||||
"poll_single_host: failed to build AgentClient"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let status = match client.job_status(&row.agent_job_id).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
pjh_id = %row.id,
|
||||
agent_job_id = %row.agent_job_id,
|
||||
error = %e,
|
||||
"poll_single_host: agent status call failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match status.status.as_str() {
|
||||
"succeeded" => {
|
||||
tracing::info!(pjh_id = %row.id, "poll_single_host: agent job succeeded");
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'succeeded',
|
||||
completed_at = NOW(),
|
||||
output = $2
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(row.id)
|
||||
.bind(status.output.as_deref())
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(pjh_id = %row.id, error = %e, "poll_single_host: update failed");
|
||||
}
|
||||
sync_job_status(&pool, row.job_id).await;
|
||||
}
|
||||
"failed" => {
|
||||
tracing::warn!(pjh_id = %row.id, "poll_single_host: agent job failed");
|
||||
let err_msg = status
|
||||
.error
|
||||
.unwrap_or_else(|| "Agent reported failure (no detail)".to_string());
|
||||
handle_host_failure(pool, row.id, err_msg).await;
|
||||
}
|
||||
"running" | "queued" => {
|
||||
// Still in progress — nothing to update; will poll again next cycle.
|
||||
tracing::debug!(
|
||||
pjh_id = %row.id,
|
||||
agent_status = %status.status,
|
||||
"poll_single_host: job still in progress"
|
||||
);
|
||||
}
|
||||
other => {
|
||||
tracing::warn!(
|
||||
pjh_id = %row.id,
|
||||
agent_status = %other,
|
||||
"poll_single_host: unexpected agent status — ignoring"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// handle_host_failure
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Apply exponential back-off retry logic to a failed host job entry.
|
||||
///
|
||||
/// Retries up to 3 times (1 min / 5 min / 30 min delays). After the third
|
||||
/// failure the entry is marked `failed` and the parent job status is synced.
|
||||
async fn handle_host_failure(pool: PgPool, pjh_id: Uuid, error_msg: String) {
|
||||
let row: Option<RetryRow> = match sqlx::query_as(
|
||||
"SELECT job_id, retry_count FROM patch_job_hosts WHERE id = $1",
|
||||
)
|
||||
.bind(pjh_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(%pjh_id, error = %e, "handle_host_failure: DB error fetching retry row");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let row = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
tracing::error!(%pjh_id, "handle_host_failure: pjh row not found");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if row.retry_count < 3 {
|
||||
let new_retry_count = row.retry_count + 1;
|
||||
let retry_next_at = Utc::now()
|
||||
+ match new_retry_count {
|
||||
1 => ChronoDuration::minutes(1),
|
||||
2 => ChronoDuration::minutes(5),
|
||||
_ => ChronoDuration::minutes(30),
|
||||
};
|
||||
|
||||
tracing::warn!(
|
||||
%pjh_id,
|
||||
retry_count = new_retry_count,
|
||||
?retry_next_at,
|
||||
error = %error_msg,
|
||||
"handle_host_failure: scheduling retry"
|
||||
);
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'pending',
|
||||
retry_count = $2,
|
||||
retry_next_at = $3,
|
||||
last_error = $4
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(pjh_id)
|
||||
.bind(new_retry_count)
|
||||
.bind(retry_next_at)
|
||||
.bind(&error_msg)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%pjh_id, error = %e, "handle_host_failure: failed to set pending");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(
|
||||
%pjh_id,
|
||||
retry_count = row.retry_count,
|
||||
error = %error_msg,
|
||||
"handle_host_failure: max retries exceeded, marking failed"
|
||||
);
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = 'failed',
|
||||
error_message = $2,
|
||||
completed_at = NOW()
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(pjh_id)
|
||||
.bind(&error_msg)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%pjh_id, error = %e, "handle_host_failure: failed to mark pjh failed");
|
||||
}
|
||||
|
||||
sync_job_status(&pool, row.job_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// sync_job_status
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Roll up `patch_job_hosts` aggregate status into the parent `patch_jobs` row.
|
||||
///
|
||||
/// Logic (in priority order):
|
||||
/// 1. Any `running` or `pending` hosts → keep parent `running`.
|
||||
/// 2. All hosts `succeeded` → parent `succeeded`.
|
||||
/// 3. All hosts `cancelled` → parent `cancelled`.
|
||||
/// 4. Any `failed` with none still active → parent `failed` (includes partial).
|
||||
async fn sync_job_status(pool: &PgPool, job_id: Uuid) {
|
||||
let counts: StatusCounts = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE status = 'running') AS running_count,
|
||||
COUNT(*) FILTER (WHERE status = 'pending') AS pending_count,
|
||||
COUNT(*) FILTER (WHERE status = 'queued') AS queued_count,
|
||||
COUNT(*) FILTER (WHERE status = 'succeeded') AS succeeded_count,
|
||||
COUNT(*) FILTER (WHERE status = 'failed') AS failed_count,
|
||||
COUNT(*) FILTER (WHERE status = 'cancelled') AS cancelled_count,
|
||||
COUNT(*) AS total_count
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(%job_id, error = %e, "sync_job_status: DB query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Determine the aggregate status.
|
||||
let new_status: &str;
|
||||
let set_completed: bool;
|
||||
|
||||
if counts.running_count > 0 || counts.pending_count > 0 || counts.queued_count > 0 {
|
||||
// Still work in flight — keep parent running.
|
||||
new_status = "running";
|
||||
set_completed = false;
|
||||
} else if counts.total_count > 0 && counts.succeeded_count == counts.total_count {
|
||||
// Every host succeeded.
|
||||
new_status = "succeeded";
|
||||
set_completed = true;
|
||||
} else if counts.total_count > 0 && counts.cancelled_count == counts.total_count {
|
||||
// Every host cancelled.
|
||||
new_status = "cancelled";
|
||||
set_completed = true;
|
||||
} else if counts.failed_count > 0 {
|
||||
// At least one failure and nothing still active → failed (partial counts too).
|
||||
new_status = "failed";
|
||||
set_completed = true;
|
||||
} else {
|
||||
// Fallback: nothing actionable yet.
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
%job_id,
|
||||
new_status,
|
||||
running = counts.running_count,
|
||||
pending = counts.pending_count,
|
||||
queued = counts.queued_count,
|
||||
succeeded = counts.succeeded_count,
|
||||
failed = counts.failed_count,
|
||||
"sync_job_status: updating parent job"
|
||||
);
|
||||
|
||||
let result = if set_completed {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_jobs
|
||||
SET status = $2,
|
||||
completed_at = COALESCE(completed_at, NOW())
|
||||
WHERE id = $1
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(new_status)
|
||||
.execute(pool)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE patch_jobs SET status = $2 WHERE id = $1",
|
||||
)
|
||||
.bind(job_id)
|
||||
.bind(new_status)
|
||||
.execute(pool)
|
||||
.await
|
||||
};
|
||||
|
||||
if let Err(e) = result {
|
||||
tracing::error!(%job_id, error = %e, "sync_job_status: failed to update parent job");
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// retry_pending_jobs
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Find pending host entries whose back-off window has elapsed, reset them to
|
||||
/// `queued`, and dispatch them immediately.
|
||||
pub async fn retry_pending_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let rows: Vec<PatchJobHostPending> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT pjh.id, pjh.host_id, pjh.job_id
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs j ON j.id = pjh.job_id
|
||||
WHERE pjh.status = 'pending'
|
||||
AND pjh.retry_next_at <= NOW()
|
||||
AND j.status != 'cancelled'
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "retry_pending_jobs: DB query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for row in rows {
|
||||
// Reset to queued so execute_host_job can pick it up cleanly.
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE patch_job_hosts SET status = 'queued', retry_next_at = NULL WHERE id = $1",
|
||||
)
|
||||
.bind(row.id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
pjh_id = %row.id,
|
||||
error = %e,
|
||||
"retry_pending_jobs: failed to reset pjh to queued"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
pjh_id = %row.id,
|
||||
job_id = %row.job_id,
|
||||
"retry_pending_jobs: re-dispatching host job"
|
||||
);
|
||||
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
let (job_id, host_id, pjh_id) = (row.job_id, row.host_id, row.id);
|
||||
tokio::spawn(async move {
|
||||
execute_host_job(p, c, job_id, host_id, pjh_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -3,6 +3,12 @@
|
||||
//! Handles scheduled polling, job execution, maintenance window scheduling,
|
||||
//! retry logic, email notifications, and data pruning.
|
||||
|
||||
mod agent_loader;
|
||||
mod health_poller;
|
||||
mod patch_poller;
|
||||
mod refresh_listener;
|
||||
mod job_executor;
|
||||
|
||||
use pm_core::{
|
||||
config::AppConfig,
|
||||
db,
|
||||
@ -12,6 +18,11 @@ use sqlx::PgPool;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time;
|
||||
|
||||
use health_poller::run_health_poller;
|
||||
use patch_poller::run_patch_poller;
|
||||
use refresh_listener::run_refresh_listener;
|
||||
use job_executor::run_job_executor;
|
||||
|
||||
/// Minimum number of applied migrations the worker requires before
|
||||
/// accepting work. Prevents the worker from running against a schema
|
||||
/// that hasn't been migrated yet.
|
||||
@ -51,14 +62,24 @@ async fn main() -> anyhow::Result<()> {
|
||||
config.worker.heartbeat_interval_secs,
|
||||
));
|
||||
|
||||
// TODO M4: spawn health_poller, patch_data_poller
|
||||
// TODO M5: spawn job_executor
|
||||
// TODO M6: spawn job_scheduler
|
||||
// M4: agent health poller, patch data poller, on-demand refresh listener
|
||||
let health_handle = tokio::spawn(run_health_poller(pool.clone(), config.clone()));
|
||||
let patch_handle = tokio::spawn(run_patch_poller(pool.clone(), config.clone()));
|
||||
let refresh_handle = tokio::spawn(run_refresh_listener(pool.clone(), config.clone()));
|
||||
|
||||
// M5: job execution engine
|
||||
let job_exec_handle = tokio::spawn(run_job_executor(pool.clone(), config.clone()));
|
||||
|
||||
tracing::info!("Worker tasks started");
|
||||
|
||||
// Wait for all tasks (they run indefinitely)
|
||||
let _ = tokio::join!(heartbeat_handle);
|
||||
let _ = tokio::join!(
|
||||
heartbeat_handle,
|
||||
health_handle,
|
||||
patch_handle,
|
||||
refresh_handle,
|
||||
job_exec_handle,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
209
crates/pm-worker/src/patch_poller.rs
Normal file
209
crates/pm-worker/src/patch_poller.rs
Normal file
@ -0,0 +1,209 @@
|
||||
//! Periodic patch-data poller for all registered hosts.
|
||||
//!
|
||||
//! Polls every host via the agent `/patches` and `/packages` endpoints on
|
||||
//! each tick of `patch_poll_interval_secs`, with bounded concurrency
|
||||
//! controlled by a [`tokio::sync::Semaphore`].
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::AgentClient;
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::{
|
||||
sync::Semaphore,
|
||||
time,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host projection fetched for each poll cycle.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the patch poller loop indefinitely.
|
||||
///
|
||||
/// On each tick all registered hosts are queried concurrently (up to
|
||||
/// `max_concurrent_agent_calls` in-flight at once). Results are persisted
|
||||
/// to `host_patch_data` and `hosts.last_patch_at` is updated.
|
||||
pub async fn run_patch_poller(pool: PgPool, config: Arc<AppConfig>) {
|
||||
let interval_secs = config.worker.patch_poll_interval_secs;
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
|
||||
tracing::info!(
|
||||
interval_secs,
|
||||
"Patch poller started"
|
||||
);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller: failed to load agent certs — skipping cycle");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let client_cert = Arc::new(certs.client_cert);
|
||||
let client_key = Arc::new(certs.client_key);
|
||||
let ca_cert = Arc::new(certs.ca_cert);
|
||||
|
||||
let hosts: Vec<HostRow> = match sqlx::query_as(
|
||||
"SELECT id, ip_address::text AS ip_address, agent_port FROM hosts ORDER BY id",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller: failed to fetch hosts");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if hosts.is_empty() {
|
||||
tracing::debug!("Patch poller: no hosts registered, skipping cycle");
|
||||
continue;
|
||||
}
|
||||
|
||||
let total = hosts.len();
|
||||
let semaphore = Arc::new(Semaphore::new(config.worker.max_concurrent_agent_calls));
|
||||
|
||||
let mut handles = Vec::with_capacity(total);
|
||||
|
||||
for host in hosts {
|
||||
let pool = pool.clone();
|
||||
let sem = semaphore.clone();
|
||||
let cert = client_cert.clone();
|
||||
let key = client_key.clone();
|
||||
let ca = ca_cert.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = sem.acquire().await.expect("semaphore closed");
|
||||
poll_host_patches(pool, host, &cert, &key, &ca).await
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
let mut succeeded = 0usize;
|
||||
let mut failed = 0usize;
|
||||
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(true) => succeeded += 1,
|
||||
Ok(false) => failed += 1,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Patch poller task panicked");
|
||||
failed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
total,
|
||||
succeeded,
|
||||
failed,
|
||||
"Patch poll cycle complete"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll a single host for patch and package data, persist the result.
|
||||
/// Returns `true` on success, `false` on any error.
|
||||
async fn poll_host_patches(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) -> bool {
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: failed to build AgentClient");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch patches and packages concurrently.
|
||||
let (patches_result, packages_result) =
|
||||
tokio::join!(client.patches(), client.packages_upgradable());
|
||||
|
||||
let patches_data = match patches_result {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: patches() failed");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
let packages_data = match packages_result {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Patch poller: packages_upgradable() failed");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
let available_patches = serde_json::to_value(&patches_data.patches).unwrap_or_default();
|
||||
let installed_packages = serde_json::to_value(&packages_data.packages).unwrap_or_default();
|
||||
let patch_count = patches_data.total as i32;
|
||||
let cve_count = patches_data
|
||||
.patches
|
||||
.iter()
|
||||
.filter(|p| !p.cve_ids.is_empty())
|
||||
.count() as i32;
|
||||
|
||||
// Insert into host_patch_data.
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_patch_data
|
||||
(host_id, available_patches, installed_packages, patch_count, cve_count)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&available_patches)
|
||||
.bind(&installed_packages)
|
||||
.bind(patch_count)
|
||||
.bind(cve_count)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to insert patch data");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update hosts.last_patch_at.
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE hosts SET last_patch_at = NOW() WHERE id = $1",
|
||||
)
|
||||
.bind(host.id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(host_id = %host.id, error = %e, "Patch poller: failed to update last_patch_at");
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
host_id = %host.id,
|
||||
patch_count,
|
||||
cve_count,
|
||||
"Patch data collected"
|
||||
);
|
||||
|
||||
true
|
||||
}
|
||||
265
crates/pm-worker/src/refresh_listener.rs
Normal file
265
crates/pm-worker/src/refresh_listener.rs
Normal file
@ -0,0 +1,265 @@
|
||||
//! On-demand refresh listener.
|
||||
//!
|
||||
//! Listens on the PostgreSQL `refresh_requested` NOTIFY channel. When a
|
||||
//! notification arrives the payload is expected to be a host UUID string.
|
||||
//! The listener immediately polls that host for health and patch data and
|
||||
//! persists the results — bypassing the normal poll intervals.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_agent_client::{AgentClient, AgentClientError};
|
||||
use pm_core::{
|
||||
config::AppConfig,
|
||||
models::HostHealthStatus,
|
||||
};
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::time;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent_loader::load_agent_certs;
|
||||
|
||||
/// Minimal host row used for on-demand refresh.
|
||||
#[derive(Debug, FromRow)]
|
||||
struct HostRow {
|
||||
id: Uuid,
|
||||
ip_address: String,
|
||||
agent_port: i32,
|
||||
}
|
||||
|
||||
/// Run the LISTEN/NOTIFY refresh listener indefinitely.
|
||||
///
|
||||
/// Automatically reconnects if the underlying PostgreSQL connection drops.
|
||||
pub async fn run_refresh_listener(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Refresh listener started — listening on 'refresh_requested'");
|
||||
|
||||
loop {
|
||||
if let Err(e) = listen_loop(&pool, &config).await {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
"Refresh listener disconnected, reconnecting in 5s"
|
||||
);
|
||||
time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inner loop — returns `Err` only on a fatal listener error so the outer
|
||||
/// loop can reconnect.
|
||||
async fn listen_loop(pool: &PgPool, config: &AppConfig) -> anyhow::Result<()> {
|
||||
let mut listener =
|
||||
sqlx::postgres::PgListener::connect(&config.database.url).await?;
|
||||
|
||||
listener.listen("refresh_requested").await?;
|
||||
|
||||
tracing::debug!("Refresh listener connected and listening");
|
||||
|
||||
loop {
|
||||
let notification = listener.recv().await?;
|
||||
let payload = notification.payload().to_string();
|
||||
|
||||
tracing::info!(payload, "Refresh notification received");
|
||||
|
||||
let host_id = match payload.parse::<Uuid>() {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
payload,
|
||||
error = %e,
|
||||
"Refresh listener: invalid UUID in notification payload"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch the host from the database.
|
||||
let host: Option<HostRow> = sqlx::query_as(
|
||||
"SELECT id, ip_address::text AS ip_address, agent_port FROM hosts WHERE id = $1",
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let host = match host {
|
||||
Some(h) => h,
|
||||
None => {
|
||||
tracing::warn!(%host_id, "Refresh listener: host not found");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Load certs for this refresh.
|
||||
let certs = match load_agent_certs(&config.security) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
%host_id,
|
||||
error = %e,
|
||||
"Refresh listener: failed to load agent certs"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Spawn the actual work so the listener loop is not blocked.
|
||||
let pool_clone = pool.clone();
|
||||
let cert = certs.client_cert;
|
||||
let key = certs.client_key;
|
||||
let ca = certs.ca_cert;
|
||||
|
||||
tokio::spawn(async move {
|
||||
refresh_host(pool_clone, host, &cert, &key, &ca).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a full health + patch refresh for one host and persist results.
|
||||
async fn refresh_host(
|
||||
pool: PgPool,
|
||||
host: HostRow,
|
||||
client_cert: &[u8],
|
||||
client_key: &[u8],
|
||||
ca_cert: &[u8],
|
||||
) {
|
||||
let client = match AgentClient::new(
|
||||
&host.ip_address,
|
||||
host.agent_port as u16,
|
||||
client_cert,
|
||||
client_key,
|
||||
ca_cert,
|
||||
) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to build AgentClient"
|
||||
);
|
||||
persist_health_unreachable(&pool, host.id).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// ── Health ────────────────────────────────────────────────────────────
|
||||
let (health_status, health_payload) = match client.health().await {
|
||||
Ok(data) => {
|
||||
let payload = serde_json::to_value(&data).unwrap_or_default();
|
||||
(HostHealthStatus::Healthy, payload)
|
||||
}
|
||||
Err(AgentClientError::Timeout) | Err(AgentClientError::Connect(_)) => {
|
||||
tracing::warn!(host_id = %host.id, "Refresh: agent unreachable");
|
||||
(HostHealthStatus::Unreachable, serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(host_id = %host.id, error = %e, "Refresh: health error");
|
||||
(HostHealthStatus::Degraded, serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
};
|
||||
|
||||
persist_health(&pool, host.id, &health_status, &health_payload).await;
|
||||
|
||||
// ── Patch data ────────────────────────────────────────────────────────
|
||||
let (patches_result, packages_result) =
|
||||
tokio::join!(client.patches(), client.packages_upgradable());
|
||||
|
||||
match (patches_result, packages_result) {
|
||||
(Ok(patches_data), Ok(packages_data)) => {
|
||||
let available_patches =
|
||||
serde_json::to_value(&patches_data.patches).unwrap_or_default();
|
||||
let installed_packages =
|
||||
serde_json::to_value(&packages_data.packages).unwrap_or_default();
|
||||
let patch_count = patches_data.total as i32;
|
||||
let cve_count = patches_data
|
||||
.patches
|
||||
.iter()
|
||||
.filter(|p| !p.cve_ids.is_empty())
|
||||
.count() as i32;
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_patch_data
|
||||
(host_id, available_patches, installed_packages, patch_count, cve_count)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
"#,
|
||||
)
|
||||
.bind(host.id)
|
||||
.bind(&available_patches)
|
||||
.bind(&installed_packages)
|
||||
.bind(patch_count)
|
||||
.bind(cve_count)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to insert patch data"
|
||||
);
|
||||
} else {
|
||||
let _ = sqlx::query(
|
||||
"UPDATE hosts SET last_patch_at = NOW() WHERE id = $1",
|
||||
)
|
||||
.bind(host.id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
host_id = %host.id,
|
||||
patch_count,
|
||||
cve_count,
|
||||
"On-demand refresh complete"
|
||||
);
|
||||
}
|
||||
}
|
||||
(Err(e), _) | (_, Err(e)) => {
|
||||
tracing::warn!(
|
||||
host_id = %host.id,
|
||||
error = %e,
|
||||
"Refresh: failed to collect patch data"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn persist_health_unreachable(pool: &PgPool, host_id: Uuid) {
|
||||
let status = HostHealthStatus::Unreachable;
|
||||
let payload = serde_json::Value::Object(Default::default());
|
||||
persist_health(pool, host_id, &status, &payload).await;
|
||||
}
|
||||
|
||||
async fn persist_health(
|
||||
pool: &PgPool,
|
||||
host_id: Uuid,
|
||||
status: &HostHealthStatus,
|
||||
payload: &serde_json::Value,
|
||||
) {
|
||||
if let Err(e) = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO host_health_data (host_id, status, payload)
|
||||
VALUES ($1, $2, $3)
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(status)
|
||||
.bind(payload)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
%host_id,
|
||||
error = %e,
|
||||
"Refresh: failed to insert health data"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE hosts SET health_status = $2, last_health_at = NOW() WHERE id = $1",
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(status)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(%host_id, error = %e, "Refresh: failed to update host health_status");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user