//! WS relay — M7 //! //! For every running `patch_job_hosts` row that has an `agent_job_id`, open a //! WebSocket to the corresponding agent, stream job-status events, update the //! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS //! handler can forward the event to connected clients. use std::{collections::HashSet, sync::Arc, time::Duration}; use anyhow::Context; use futures::StreamExt; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer}, ClientConfig as TlsClientConfig, RootCertStore, }; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use tokio::sync::Mutex; use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::protocol::Message, Connector}; use uuid::Uuid; use pm_agent_client::client::DEFAULT_AGENT_PORT; use pm_core::config::AppConfig; // ── Types ───────────────────────────────────────────────────────────────────── #[derive(Debug, sqlx::FromRow)] struct RunningHostJob { job_id: Uuid, host_id: Uuid, agent_job_id: String, host_address: String, } /// JSON event streamed by the agent over its WS endpoint. #[derive(Debug, Deserialize)] struct AgentWsEvent { #[allow(dead_code)] job_id: String, status: String, output: Option, error: Option, #[allow(dead_code)] progress_percent: Option, } /// Payload broadcast via `pg_notify('job_update', …)`. #[derive(Debug, Serialize)] struct NotifyPayload { job_id: String, host_id: String, status: String, output: Option, error_message: Option, agent_job_id: String, } // ── Entry point ─────────────────────────────────────────────────────────────── /// Long-running task: polls the DB for running host-jobs and spawns a per-pair /// relay task for each one that isn't already being tracked. pub async fn run_ws_relay(pool: PgPool, config: Arc) { tracing::info!("WS relay task started"); let active: Arc>> = Arc::new(Mutex::new(HashSet::new())); let mut interval = tokio::time::interval(Duration::from_secs(10)); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { interval.tick().await; let rows = match query_running_jobs(&pool).await { Ok(r) => r, Err(e) => { tracing::error!(error = %e, "ws_relay: DB poll failed"); continue; }, }; for row in rows { let key = (row.job_id, row.host_id); // Skip pairs that already have an active relay. if active.lock().await.contains(&key) { continue; } // Build the rustls ClientConfig once per connection. let tls_config = match build_tls_config(&config).await { Ok(c) => Arc::new(c), Err(e) => { tracing::error!(error = %e, "ws_relay: TLS config error"); continue; }, }; active.lock().await.insert(key); let pool_c = pool.clone(); let active_c = active.clone(); tokio::spawn(async move { tracing::info!( job_id = %row.job_id, host_id = %row.host_id, agent_job_id = %row.agent_job_id, host = %row.host_address, "WS relay: starting relay" ); match relay_one_job(&pool_c, &row, tls_config).await { Ok(()) => tracing::info!( job_id = %row.job_id, host_id = %row.host_id, "WS relay: completed" ), Err(e) => tracing::error!( error = %e, job_id = %row.job_id, host_id = %row.host_id, "WS relay: ended with error" ), } active_c.lock().await.remove(&key); }); } } } // ── DB helpers ──────────────────────────────────────────────────────────────── async fn query_running_jobs(pool: &PgPool) -> anyhow::Result> { sqlx::query_as::<_, RunningHostJob>( r#" SELECT pjh.job_id, pjh.host_id, pjh.agent_job_id, COALESCE(h.fqdn, h.ip_address::text) AS host_address FROM patch_job_hosts pjh JOIN hosts h ON h.id = pjh.host_id WHERE pjh.status = 'running'::job_status AND pjh.agent_job_id IS NOT NULL "#, ) .fetch_all(pool) .await .context("query_running_jobs") } // ── TLS ─────────────────────────────────────────────────────────────────────── async fn build_tls_config(config: &AppConfig) -> anyhow::Result { let sec = &config.security; let cert_pem = tokio::fs::read(&sec.agent_client_cert_path) .await .with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?; let key_pem = tokio::fs::read(&sec.agent_client_key_path) .await .with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?; let ca_pem = tokio::fs::read(&sec.ca_cert_path) .await .with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?; // Parse client certificate chain. let client_certs: Vec> = { let mut cur = std::io::Cursor::new(&cert_pem); rustls_pemfile::certs(&mut cur) .collect::, _>>() .context("parse client cert PEM")? }; // Parse client private key. let client_key: PrivateKeyDer<'static> = { let mut cur = std::io::Cursor::new(&key_pem); rustls_pemfile::private_key(&mut cur) .context("parse client key PEM")? .context("no private key in PEM")? }; // Build root store from CA cert. let mut root_store = RootCertStore::empty(); { let mut cur = std::io::Cursor::new(&ca_pem); for cert_result in rustls_pemfile::certs(&mut cur) { root_store .add(cert_result.context("read CA cert entry")?) .context("add CA cert to root store")?; } } TlsClientConfig::builder() .with_root_certificates(root_store) .with_client_auth_cert(client_certs, client_key) .context("build TlsClientConfig") } // ── Per-job relay ───────────────────────────────────────────────────────────── async fn relay_one_job( pool: &PgPool, row: &RunningHostJob, tls_config: Arc, ) -> anyhow::Result<()> { let url = format!( "wss://{}:{}/api/v1/ws/jobs", row.host_address, DEFAULT_AGENT_PORT, ); let (ws_stream, _) = connect_async_tls_with_config( url.as_str(), None, false, Some(Connector::Rustls(tls_config)), ) .await .with_context(|| format!("connect agent WS {url}"))?; let (_sink, mut stream) = ws_stream.split(); while let Some(frame) = stream.next().await { let frame = match frame { Ok(f) => f, Err(e) => { tracing::warn!( error = %e, job_id = %row.job_id, host_id = %row.host_id, "WS relay: stream error" ); break; }, }; let text = match frame { Message::Text(t) => t.to_string(), Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(), Message::Close(_) => { tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly"); break; }, _ => continue, }; if text.is_empty() { continue; } let event: AgentWsEvent = match serde_json::from_str(&text) { Ok(e) => e, Err(e) => { tracing::warn!( error = %e, raw = %text, "WS relay: unparseable agent frame" ); continue; }, }; process_event(pool, row, &event).await; if matches!(event.status.as_str(), "succeeded" | "failed" | "cancelled") { tracing::info!( job_id = %row.job_id, host_id = %row.host_id, status = %event.status, "WS relay: terminal state — stopping" ); break; } } Ok(()) } // ── Event processing ────────────────────────────────────────────────────────── async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) { // Map agent status string to DB job_status enum value. let db_status = match event.status.as_str() { "running" => "running", "succeeded" => "succeeded", "failed" => "failed", "cancelled" => "cancelled", other => { tracing::warn!(status = %other, "WS relay: unknown agent status"); return; }, }; let output = event.output.as_deref().unwrap_or(""); let error_msg = event.error.as_deref(); // Determine timestamps based on terminal state. let is_terminal = matches!(db_status, "succeeded" | "failed" | "cancelled"); // Update the DB row. let update_result = if is_terminal { sqlx::query( r#" UPDATE patch_job_hosts SET status = $1::job_status, output = CASE WHEN $2 != '' THEN $2 ELSE output END, error_message = $3, completed_at = NOW() WHERE job_id = $4 AND host_id = $5 "#, ) .bind(db_status) .bind(output) .bind(error_msg) .bind(row.job_id) .bind(row.host_id) .execute(pool) .await } else { sqlx::query( r#" UPDATE patch_job_hosts SET status = $1::job_status, output = CASE WHEN $2 != '' THEN $2 ELSE output END WHERE job_id = $3 AND host_id = $4 "#, ) .bind(db_status) .bind(output) .bind(row.job_id) .bind(row.host_id) .execute(pool) .await }; if let Err(e) = update_result { tracing::error!( error = %e, job_id = %row.job_id, host_id = %row.host_id, "WS relay: DB update failed" ); return; } // Also update the parent patch_jobs status when the host-level job reaches // a terminal state: running → if all hosts terminal then update parent. if is_terminal { update_parent_job_status(pool, row.job_id).await; } // Fire pg_notify so browser WS handlers forward the event. let payload = NotifyPayload { job_id: row.job_id.to_string(), host_id: row.host_id.to_string(), status: db_status.to_string(), output: event.output.clone(), error_message: event.error.clone(), agent_job_id: row.agent_job_id.clone(), }; let payload_json = match serde_json::to_string(&payload) { Ok(s) => s, Err(e) => { tracing::error!(error = %e, "WS relay: failed to serialize notify payload"); return; }, }; if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)") .bind(&payload_json) .execute(pool) .await { tracing::error!( error = %e, job_id = %row.job_id, host_id = %row.host_id, "WS relay: pg_notify failed" ); } else { tracing::debug!( job_id = %row.job_id, host_id = %row.host_id, status = %db_status, "WS relay: pg_notify sent" ); } } // ── Parent job status rollup ────────────────────────────────────────────────── /// After a host-level job reaches a terminal state, check whether ALL hosts for /// that job are now terminal and update the parent `patch_jobs` row accordingly. async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) { // Count hosts that are still in a non-terminal state. let pending: i64 = match sqlx::query_scalar( r#" SELECT COUNT(*) FROM patch_job_hosts WHERE job_id = $1 AND status NOT IN ( 'succeeded'::job_status, 'failed'::job_status, 'cancelled'::job_status ) "#, ) .bind(job_id) .fetch_one(pool) .await { Ok(n) => n, Err(e) => { tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed"); return; }, }; if pending > 0 { return; // still hosts running — parent stays running } // All hosts terminal — determine final parent status. let failed_count: i64 = match sqlx::query_scalar( "SELECT COUNT(*) FROM patch_job_hosts WHERE job_id = $1 AND status = 'failed'::job_status", ) .bind(job_id) .fetch_one(pool) .await { Ok(n) => n, Err(e) => { tracing::error!(error = %e, %job_id, "update_parent_job_status: failed-count query failed"); return; }, }; let final_status = if failed_count > 0 { "failed" } else { "succeeded" }; if let Err(e) = sqlx::query( "UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2", ) .bind(final_status) .bind(job_id) .execute(pool) .await { tracing::error!( error = %e, %job_id, status = %final_status, "update_parent_job_status: UPDATE failed" ); } else { tracing::info!( %job_id, status = %final_status, "Parent job status updated" ); } }