feat: M6 maintenance windows + M7 WebSocket relay (real-time job status)
M6 - Maintenance Windows: - routes/maintenance_windows.rs: full CRUD API - migrations/004_maintenance_windows.sql - frontend/MaintenanceWindowsPage.tsx - HostDetailPage.tsx: maintenance window config panel M7 - WebSocket Relay: - pm-web: POST /api/v1/ws/ticket (JWT-auth, single-use, 60s TTL) - pm-web: WS /api/v1/ws/jobs?ticket=... (PgListener -> browser push) - pm-web: DashMap<String,WsTicket> in AppState, 30s cleanup task - pm-worker: ws_relay.rs subscribes to agent WS, updates patch_job_hosts, fires pg_notify(job_update) for real-time fan-out - frontend: useJobWebSocket hook with auto-reconnect + exponential backoff - frontend: JobsPage live updates with WS status indicator - types: JobWsEvent interface - api/client: wsApi.createTicket() All tasks marked complete in tasks/todo.md cargo build: zero errors, zero warnings
This commit is contained in:
@ -298,3 +298,77 @@ pub struct PatchJobSummary {
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Maintenance Windows
|
||||
// ============================================================
|
||||
|
||||
/// Recurrence type for a maintenance window.
|
||||
/// Mirrors the `window_recurrence` PostgreSQL ENUM.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
|
||||
#[sqlx(type_name = "window_recurrence", rename_all = "lowercase")]
|
||||
pub enum WindowRecurrence {
|
||||
/// Single one-time window (at `start_at` for `duration_minutes` minutes).
|
||||
Once,
|
||||
/// Repeats every day at the time portion of `start_at`.
|
||||
Daily,
|
||||
/// Repeats on the day-of-week in `recurrence_day` (0 = Sunday).
|
||||
Weekly,
|
||||
/// Repeats on the day-of-month in `recurrence_day` (1-31).
|
||||
Monthly,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WindowRecurrence {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Once => write!(f, "once"),
|
||||
Self::Daily => write!(f, "daily"),
|
||||
Self::Weekly => write!(f, "weekly"),
|
||||
Self::Monthly => write!(f, "monthly"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Full row from `maintenance_windows`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct MaintenanceWindow {
|
||||
pub id: Uuid,
|
||||
pub host_id: Uuid,
|
||||
pub label: String,
|
||||
pub recurrence: WindowRecurrence,
|
||||
/// Absolute start time (one-time) or time-of-day reference (recurring).
|
||||
pub start_at: DateTime<Utc>,
|
||||
/// Duration of the window in minutes.
|
||||
pub duration_minutes: i32,
|
||||
/// Day-of-week (0=Sun, weekly) or day-of-month (1-31, monthly); NULL for once/daily.
|
||||
pub recurrence_day: Option<i32>,
|
||||
pub enabled: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Payload for `POST /api/v1/hosts/{id}/maintenance-windows`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateMaintenanceWindowRequest {
|
||||
pub label: String,
|
||||
pub recurrence: WindowRecurrence,
|
||||
/// RFC 3339 / ISO 8601 timestamp (UTC recommended).
|
||||
pub start_at: DateTime<Utc>,
|
||||
/// How many minutes the window is open (default 60).
|
||||
pub duration_minutes: Option<i32>,
|
||||
/// Required for `weekly` (0-6) and `monthly` (1-31).
|
||||
pub recurrence_day: Option<i32>,
|
||||
/// Whether the window is active (default true).
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
/// Payload for `PUT /api/v1/hosts/{id}/maintenance-windows/{window_id}`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateMaintenanceWindowRequest {
|
||||
pub label: Option<String>,
|
||||
pub recurrence: Option<WindowRecurrence>,
|
||||
pub start_at: Option<DateTime<Utc>>,
|
||||
pub duration_minutes: Option<i32>,
|
||||
pub recurrence_day: Option<i32>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
@ -28,3 +28,4 @@ uuid = { workspace = true }
|
||||
ulid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
dashmap = { version = "6" }
|
||||
|
||||
@ -10,6 +10,7 @@ use axum::{
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use pm_core::{
|
||||
config::AppConfig,
|
||||
db,
|
||||
@ -20,8 +21,13 @@ use pm_auth::{
|
||||
jwt,
|
||||
rbac::{AuthConfig, require_auth},
|
||||
};
|
||||
use routes::ws::WsTicket;
|
||||
use serde_json::{json, Value};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tower_http::{
|
||||
services::ServeDir,
|
||||
trace::TraceLayer,
|
||||
@ -34,6 +40,8 @@ pub struct AppState {
|
||||
pub config: Arc<AppConfig>,
|
||||
pub signing_key_pem: String,
|
||||
pub auth_config: Arc<AuthConfig>,
|
||||
/// In-memory store for single-use WebSocket authentication tickets.
|
||||
pub ws_tickets: Arc<DashMap<String, WsTicket>>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -69,11 +77,32 @@ async fn main() -> anyhow::Result<()> {
|
||||
let pool = db::init_pool(&config.database).await?;
|
||||
db::run_migrations(&pool).await?;
|
||||
|
||||
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
|
||||
|
||||
// Background task: purge expired WS tickets every 30 seconds.
|
||||
{
|
||||
let tickets = ws_tickets.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(30));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = chrono::Utc::now();
|
||||
let before = tickets.len();
|
||||
tickets.retain(|_, v| v.expires_at > now);
|
||||
let removed = before.saturating_sub(tickets.len());
|
||||
if removed > 0 {
|
||||
tracing::debug!(removed, "Purged expired WS tickets");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let state = AppState {
|
||||
db: pool,
|
||||
config: Arc::new(config.clone()),
|
||||
signing_key_pem,
|
||||
auth_config,
|
||||
ws_tickets,
|
||||
};
|
||||
|
||||
let app = build_router(state);
|
||||
@ -109,6 +138,10 @@ pub fn build_router(state: AppState) -> Router {
|
||||
.nest("/status", routes::status::router())
|
||||
// Patch jobs
|
||||
.nest("/jobs", routes::jobs::router())
|
||||
// Maintenance windows (nested under hosts path param)
|
||||
.nest("/hosts/:host_id/maintenance-windows", routes::maintenance_windows::router())
|
||||
// WS ticket issuance (JWT-protected — ticket returned to browser, then used for WS upgrade)
|
||||
.merge(routes::ws::ticket_router())
|
||||
// Apply auth middleware to all the above
|
||||
.route_layer(middleware::from_fn(move |req, next| {
|
||||
let auth_config = auth_config.clone();
|
||||
@ -121,6 +154,8 @@ pub fn build_router(state: AppState) -> Router {
|
||||
.nest("/api/v1/auth", routes::auth::public_router())
|
||||
// Protected API routes (JWT required)
|
||||
.nest("/api/v1", protected_api)
|
||||
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
|
||||
.merge(routes::ws::ws_router())
|
||||
// Serve React SPA
|
||||
.fallback_service(
|
||||
ServeDir::new(&static_dir).append_index_html_on_directories(true),
|
||||
|
||||
364
crates/pm-web/src/routes/maintenance_windows.rs
Normal file
364
crates/pm-web/src/routes/maintenance_windows.rs
Normal file
@ -0,0 +1,364 @@
|
||||
//! Maintenance window management routes.
|
||||
//!
|
||||
//! GET /api/v1/hosts/{id}/maintenance-windows — list windows for host
|
||||
//! POST /api/v1/hosts/{id}/maintenance-windows — create window for host
|
||||
//! PUT /api/v1/hosts/{id}/maintenance-windows/{win_id} — update window
|
||||
//! DELETE /api/v1/hosts/{id}/maintenance-windows/{win_id} — delete window
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, put},
|
||||
Router,
|
||||
};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use pm_core::{
|
||||
audit::{log_event, AuditAction},
|
||||
models::{
|
||||
CreateMaintenanceWindowRequest, MaintenanceWindow, UpdateMaintenanceWindowRequest,
|
||||
},
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Mount as a nested router under `/hosts/:host_id/maintenance-windows`.
|
||||
/// Axum will merge the `:host_id` path segment from the parent nest.
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(list_windows).post(create_window))
|
||||
.route("/:win_id", put(update_window).delete(delete_window))
|
||||
}
|
||||
|
||||
// ── Error helper ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── GET /api/v1/hosts/:host_id/maintenance-windows ────────────────────────────
|
||||
|
||||
async fn list_windows(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Verify host exists.
|
||||
let host_exists: bool =
|
||||
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "list_windows: host existence check failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
if !host_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
|
||||
}
|
||||
|
||||
let windows: Vec<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
WHERE host_id = $1
|
||||
ORDER BY created_at ASC
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "list_windows: query failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
Ok(Json(json!({ "windows": windows })))
|
||||
}
|
||||
|
||||
// ── POST /api/v1/hosts/:host_id/maintenance-windows ───────────────────────────
|
||||
|
||||
async fn create_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path(host_id): Path<Uuid>,
|
||||
Json(req): Json<CreateMaintenanceWindowRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Validate: weekly requires recurrence_day 0-6
|
||||
if req.recurrence == pm_core::models::WindowRecurrence::Weekly {
|
||||
match req.recurrence_day {
|
||||
Some(d) if (0..=6).contains(&d) => {}
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Weekly recurrence requires recurrence_day 0-6 (0=Sunday)",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate: monthly requires recurrence_day 1-31
|
||||
if req.recurrence == pm_core::models::WindowRecurrence::Monthly {
|
||||
match req.recurrence_day {
|
||||
Some(d) if (1..=31).contains(&d) => {}
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Monthly recurrence requires recurrence_day 1-31",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify host exists.
|
||||
let host_exists: bool =
|
||||
sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM hosts WHERE id = $1)")
|
||||
.bind(host_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "create_window: host existence check failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
if !host_exists {
|
||||
return Err(err(StatusCode::NOT_FOUND, "not_found", "Host not found"));
|
||||
}
|
||||
|
||||
let duration = req.duration_minutes.unwrap_or(60);
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
|
||||
let window: MaintenanceWindow = sqlx::query_as(
|
||||
r#"
|
||||
INSERT INTO maintenance_windows
|
||||
(host_id, label, recurrence, start_at, duration_minutes, recurrence_day, enabled)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(host_id)
|
||||
.bind(&req.label)
|
||||
.bind(&req.recurrence)
|
||||
.bind(req.start_at)
|
||||
.bind(duration)
|
||||
.bind(req.recurrence_day)
|
||||
.bind(enabled)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %host_id, "create_window: insert failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowCreated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&window.id.to_string()),
|
||||
json!({
|
||||
"host_id": host_id,
|
||||
"label": window.label,
|
||||
"recurrence": window.recurrence.to_string(),
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %window.id,
|
||||
%host_id,
|
||||
recurrence = %window.recurrence,
|
||||
user = %auth.username,
|
||||
"Maintenance window created"
|
||||
);
|
||||
|
||||
Ok(Json(json!(window)))
|
||||
}
|
||||
|
||||
// ── PUT /api/v1/hosts/:host_id/maintenance-windows/:win_id ───────────────────
|
||||
|
||||
async fn update_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
|
||||
Json(req): Json<UpdateMaintenanceWindowRequest>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// Fetch existing record (verify ownership and existence).
|
||||
let existing: Option<MaintenanceWindow> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, created_at, updated_at
|
||||
FROM maintenance_windows
|
||||
WHERE id = $1 AND host_id = $2
|
||||
"#,
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "update_window: fetch failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
let existing = existing.ok_or_else(|| {
|
||||
err(StatusCode::NOT_FOUND, "not_found", "Maintenance window not found")
|
||||
})?;
|
||||
|
||||
// Apply partial updates using existing values as defaults.
|
||||
let new_label = req.label.unwrap_or(existing.label);
|
||||
let new_recurrence = req.recurrence.unwrap_or(existing.recurrence);
|
||||
let new_start_at = req.start_at.unwrap_or(existing.start_at);
|
||||
let new_duration = req.duration_minutes.unwrap_or(existing.duration_minutes);
|
||||
let new_rec_day = req.recurrence_day.or(existing.recurrence_day);
|
||||
let new_enabled = req.enabled.unwrap_or(existing.enabled);
|
||||
|
||||
// Validate recurrence_day for the final recurrence type.
|
||||
if new_recurrence == pm_core::models::WindowRecurrence::Weekly {
|
||||
match new_rec_day {
|
||||
Some(d) if (0..=6).contains(&d) => {}
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Weekly recurrence requires recurrence_day 0-6",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
if new_recurrence == pm_core::models::WindowRecurrence::Monthly {
|
||||
match new_rec_day {
|
||||
Some(d) if (1..=31).contains(&d) => {}
|
||||
_ => {
|
||||
return Err(err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"bad_request",
|
||||
"Monthly recurrence requires recurrence_day 1-31",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let updated: MaintenanceWindow = sqlx::query_as(
|
||||
r#"
|
||||
UPDATE maintenance_windows
|
||||
SET label = $3,
|
||||
recurrence = $4,
|
||||
start_at = $5,
|
||||
duration_minutes = $6,
|
||||
recurrence_day = $7,
|
||||
enabled = $8,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND host_id = $2
|
||||
RETURNING id, host_id, label, recurrence, start_at, duration_minutes,
|
||||
recurrence_day, enabled, created_at, updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.bind(&new_label)
|
||||
.bind(&new_recurrence)
|
||||
.bind(new_start_at)
|
||||
.bind(new_duration)
|
||||
.bind(new_rec_day)
|
||||
.bind(new_enabled)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "update_window: update failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowUpdated,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&win_id.to_string()),
|
||||
json!({ "host_id": host_id }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %win_id,
|
||||
%host_id,
|
||||
user = %auth.username,
|
||||
"Maintenance window updated"
|
||||
);
|
||||
|
||||
Ok(Json(json!(updated)))
|
||||
}
|
||||
|
||||
// ── DELETE /api/v1/hosts/:host_id/maintenance-windows/:win_id ────────────────
|
||||
|
||||
async fn delete_window(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
Path((host_id, win_id)): Path<(Uuid, Uuid)>,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM maintenance_windows WHERE id = $1 AND host_id = $2",
|
||||
)
|
||||
.bind(win_id)
|
||||
.bind(host_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, %win_id, "delete_window: delete failed");
|
||||
err(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", "Database error")
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(err(
|
||||
StatusCode::NOT_FOUND,
|
||||
"not_found",
|
||||
"Maintenance window not found",
|
||||
));
|
||||
}
|
||||
|
||||
log_event(
|
||||
&state.db,
|
||||
AuditAction::MaintenanceWindowDeleted,
|
||||
Some(auth.user_id),
|
||||
Some(&auth.username),
|
||||
Some("maintenance_window"),
|
||||
Some(&win_id.to_string()),
|
||||
json!({ "host_id": host_id }),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
window_id = %win_id,
|
||||
%host_id,
|
||||
user = %auth.username,
|
||||
"Maintenance window deleted"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "message": "Maintenance window deleted" })))
|
||||
}
|
||||
@ -3,6 +3,8 @@ pub mod auth;
|
||||
pub mod discovery;
|
||||
pub mod groups;
|
||||
pub mod hosts;
|
||||
pub mod maintenance_windows;
|
||||
pub mod jobs;
|
||||
pub mod status;
|
||||
pub mod users;
|
||||
pub mod ws;
|
||||
|
||||
212
crates/pm-web/src/routes/ws.rs
Normal file
212
crates/pm-web/src/routes/ws.rs
Normal file
@ -0,0 +1,212 @@
|
||||
//! WebSocket relay routes — M7
|
||||
//!
|
||||
//! POST /api/v1/ws/ticket — create a single-use WS auth ticket (JWT-protected)
|
||||
//! GET /api/v1/ws/jobs — browser WebSocket endpoint (ticket-authenticated)
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State, WebSocketUpgrade},
|
||||
extract::ws::{Message, WebSocket},
|
||||
http::StatusCode,
|
||||
response::{Json, Response},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use chrono::{Duration, Utc};
|
||||
use pm_auth::rbac::AuthUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::postgres::PgListener;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// ── WsTicket ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Single-use WebSocket authentication ticket stored in-memory.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WsTicket {
|
||||
pub user_id: Uuid,
|
||||
pub role: String,
|
||||
pub expires_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Router for ticket-issuance endpoint (JWT-protected, merged into protected_api).
|
||||
pub fn ticket_router() -> Router<AppState> {
|
||||
Router::new().route("/ws/ticket", post(create_ticket_handler))
|
||||
}
|
||||
|
||||
/// Router for the WebSocket endpoint (ticket-authenticated, NO JWT middleware).
|
||||
pub fn ws_router() -> Router<AppState> {
|
||||
Router::new().route("/api/v1/ws/jobs", get(ws_handler))
|
||||
}
|
||||
|
||||
// ── Error helper ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[inline]
|
||||
fn err(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
message: impl Into<String>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
(
|
||||
status,
|
||||
Json(json!({ "error": { "code": code, "message": message.into() } })),
|
||||
)
|
||||
}
|
||||
|
||||
// ── POST /api/v1/ws/ticket ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
/// Issue a single-use WebSocket authentication ticket (60 s expiry).
|
||||
pub async fn create_ticket_handler(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthUser,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
let ticket_id = Ulid::new().to_string();
|
||||
let expires_at = Utc::now() + Duration::seconds(60);
|
||||
|
||||
let ticket = WsTicket {
|
||||
user_id: auth.user_id,
|
||||
role: auth.role.as_str().to_string(),
|
||||
expires_at,
|
||||
};
|
||||
|
||||
state.ws_tickets.insert(ticket_id.clone(), ticket);
|
||||
|
||||
tracing::info!(
|
||||
user_id = %auth.user_id,
|
||||
username = %auth.username,
|
||||
ticket = %ticket_id,
|
||||
"WS ticket issued"
|
||||
);
|
||||
|
||||
Ok(Json(json!({ "ticket": ticket_id })))
|
||||
}
|
||||
|
||||
// ── GET /api/v1/ws/jobs ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsQuery {
|
||||
pub ticket: String,
|
||||
}
|
||||
|
||||
/// Browser WebSocket upgrade endpoint — authenticates via single-use ticket.
|
||||
pub async fn ws_handler(
|
||||
State(state): State<AppState>,
|
||||
Query(q): Query<WsQuery>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, (StatusCode, Json<Value>)> {
|
||||
// Validate and consume the ticket atomically.
|
||||
let ticket = {
|
||||
let entry = state.ws_tickets.get(&q.ticket);
|
||||
match entry {
|
||||
None => {
|
||||
return Err(err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid_ticket",
|
||||
"WebSocket ticket not found or already used",
|
||||
));
|
||||
}
|
||||
Some(t) => {
|
||||
if t.expires_at < Utc::now() {
|
||||
drop(t);
|
||||
state.ws_tickets.remove(&q.ticket);
|
||||
return Err(err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"ticket_expired",
|
||||
"WebSocket ticket has expired",
|
||||
));
|
||||
}
|
||||
t.clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
// Single-use: remove immediately after validation.
|
||||
state.ws_tickets.remove(&q.ticket);
|
||||
|
||||
tracing::info!(
|
||||
user_id = %ticket.user_id,
|
||||
role = %ticket.role,
|
||||
"Browser WebSocket connection upgraded"
|
||||
);
|
||||
|
||||
let db = state.db.clone();
|
||||
Ok(ws.on_upgrade(move |socket| handle_browser_ws(socket, db, ticket)))
|
||||
}
|
||||
|
||||
// ── WebSocket handler ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Drive the browser WebSocket: LISTEN on `job_update` and forward payloads.
|
||||
async fn handle_browser_ws(
|
||||
mut socket: WebSocket,
|
||||
db: sqlx::PgPool,
|
||||
ticket: WsTicket,
|
||||
) {
|
||||
// Acquire a dedicated PG listener connection.
|
||||
let mut listener = match PgListener::connect_with(&db).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "Failed to create PgListener");
|
||||
let _ = socket
|
||||
.send(Message::Text(
|
||||
json!({ "error": "internal_error" }).to_string().into(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = listener.listen("job_update").await {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener LISTEN failed");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS: LISTEN job_update started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward PG notifications to the browser.
|
||||
notify_result = listener.recv() => {
|
||||
match notify_result {
|
||||
Ok(notification) => {
|
||||
let payload = notification.payload().to_string();
|
||||
tracing::debug!(user_id = %ticket.user_id, payload = %payload, "Forwarding job_update");
|
||||
if socket.send(Message::Text(payload.into())).await.is_err() {
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS send failed — client disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, user_id = %ticket.user_id, "PgListener recv error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle incoming frames from the browser (ping/close).
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS closed by client");
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
if socket.send(Message::Pong(data)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
tracing::debug!(error = %e, user_id = %ticket.user_id, "Browser WS recv error");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(user_id = %ticket.user_id, "Browser WS handler exiting");
|
||||
}
|
||||
@ -23,3 +23,7 @@ tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { version = "0.26" }
|
||||
rustls-pemfile = { version = "2" }
|
||||
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
|
||||
|
||||
@ -204,6 +204,39 @@ async fn scan_queued_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
WHERE pjh.status = 'queued'
|
||||
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
|
||||
AND j.status != 'cancelled'
|
||||
AND (
|
||||
-- Immediate jobs always dispatch
|
||||
j.immediate = TRUE
|
||||
OR
|
||||
-- Non-immediate jobs only dispatch when the host has an open window
|
||||
EXISTS (
|
||||
SELECT 1 FROM maintenance_windows mw
|
||||
WHERE mw.host_id = pjh.host_id
|
||||
AND mw.enabled = TRUE
|
||||
AND (
|
||||
(mw.recurrence = 'once'
|
||||
AND mw.start_at <= NOW()
|
||||
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
OR
|
||||
(mw.recurrence = 'daily'
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute')))
|
||||
OR
|
||||
(mw.recurrence = 'weekly'
|
||||
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute')))
|
||||
OR
|
||||
(mw.recurrence = 'monthly'
|
||||
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute')))
|
||||
)
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
@ -230,7 +263,7 @@ async fn scan_queued_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
|
||||
/// Fetch all queued host entries for `job_id` and dispatch them concurrently,
|
||||
/// bounded by `config.worker.max_concurrent_agent_calls`.
|
||||
async fn process_job(pool: PgPool, config: Arc<AppConfig>, job_id: Uuid) {
|
||||
pub async fn process_job(pool: PgPool, config: Arc<AppConfig>, job_id: Uuid) {
|
||||
tracing::info!(%job_id, "process_job: dispatching queued hosts");
|
||||
|
||||
// Mark the parent job as running (idempotent guard).
|
||||
|
||||
@ -5,9 +5,11 @@
|
||||
|
||||
mod agent_loader;
|
||||
mod health_poller;
|
||||
mod maintenance_scheduler;
|
||||
mod patch_poller;
|
||||
mod refresh_listener;
|
||||
mod job_executor;
|
||||
mod ws_relay;
|
||||
|
||||
use pm_core::{
|
||||
config::AppConfig,
|
||||
@ -19,9 +21,11 @@ use std::{sync::Arc, time::Duration};
|
||||
use tokio::time;
|
||||
|
||||
use health_poller::run_health_poller;
|
||||
use maintenance_scheduler::run_maintenance_scheduler;
|
||||
use patch_poller::run_patch_poller;
|
||||
use refresh_listener::run_refresh_listener;
|
||||
use job_executor::run_job_executor;
|
||||
use ws_relay::run_ws_relay;
|
||||
|
||||
/// Minimum number of applied migrations the worker requires before
|
||||
/// accepting work. Prevents the worker from running against a schema
|
||||
@ -70,6 +74,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
// M5: job execution engine
|
||||
let job_exec_handle = tokio::spawn(run_job_executor(pool.clone(), config.clone()));
|
||||
|
||||
// M6: maintenance window scheduler
|
||||
let maint_sched_handle = tokio::spawn(run_maintenance_scheduler(pool.clone(), config.clone()));
|
||||
|
||||
// M7: WS relay — streams agent job events → DB → pg_notify → browser WS
|
||||
let ws_relay_handle = tokio::spawn(run_ws_relay(pool.clone(), config.clone()));
|
||||
|
||||
tracing::info!("Worker tasks started");
|
||||
|
||||
// Wait for all tasks (they run indefinitely)
|
||||
@ -79,6 +89,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
patch_handle,
|
||||
refresh_handle,
|
||||
job_exec_handle,
|
||||
maint_sched_handle,
|
||||
ws_relay_handle,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
|
||||
164
crates/pm-worker/src/maintenance_scheduler.rs
Normal file
164
crates/pm-worker/src/maintenance_scheduler.rs
Normal file
@ -0,0 +1,164 @@
|
||||
//! Maintenance window scheduler.
|
||||
//!
|
||||
//! Polls every 60 seconds and, for each enabled maintenance window that is
|
||||
//! currently open, dispatches any queued non-immediate patch jobs associated
|
||||
//! with the window's host.
|
||||
//!
|
||||
//! A window is considered "open" when:
|
||||
//! - `once` — `start_at <= NOW() < start_at + duration_minutes * '1 minute'`
|
||||
//! - `daily` — current UTC time-of-day is within the window's daily slot
|
||||
//! - `weekly` — same as daily, but only on the matching `recurrence_day` (0=Sun)
|
||||
//! - `monthly` — same as daily, but only on the matching `recurrence_day` (1-31)
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use pm_core::config::AppConfig;
|
||||
use sqlx::{FromRow, PgPool};
|
||||
use tokio::time;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::job_executor::process_job;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Internal types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct OpenWindowHost {
|
||||
host_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
struct QueuedJobId {
|
||||
job_id: Uuid,
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public entry point
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Run the maintenance scheduler indefinitely.
|
||||
/// Spawned by `pm-worker/src/main.rs` alongside the job executor.
|
||||
pub async fn run_maintenance_scheduler(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("Maintenance scheduler started");
|
||||
|
||||
// First tick fires immediately; consume it to align with job_executor.
|
||||
let mut ticker = time::interval(std::time::Duration::from_secs(60));
|
||||
ticker.tick().await;
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
tracing::debug!("Maintenance scheduler: checking open windows");
|
||||
dispatch_open_window_jobs(pool.clone(), config.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Core dispatch logic
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Find all hosts with a currently-open maintenance window, then for each,
|
||||
/// find their queued non-immediate job entries and dispatch them.
|
||||
async fn dispatch_open_window_jobs(pool: PgPool, config: Arc<AppConfig>) {
|
||||
// ── 1. Find all host_ids with an open window right now ─────────────────
|
||||
let open_hosts: Vec<OpenWindowHost> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT mw.host_id
|
||||
FROM maintenance_windows mw
|
||||
WHERE mw.enabled = TRUE
|
||||
AND (
|
||||
-- One-time: absolute window
|
||||
( mw.recurrence = 'once'
|
||||
AND mw.start_at <= NOW()
|
||||
AND NOW() < mw.start_at + (mw.duration_minutes * INTERVAL '1 minute')
|
||||
)
|
||||
OR
|
||||
-- Daily: time-of-day slot, any day
|
||||
( mw.recurrence = 'daily'
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
-- Weekly: matching day-of-week + time-of-day slot
|
||||
( mw.recurrence = 'weekly'
|
||||
AND EXTRACT(DOW FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
OR
|
||||
-- Monthly: matching day-of-month + time-of-day slot
|
||||
( mw.recurrence = 'monthly'
|
||||
AND EXTRACT(DAY FROM NOW() AT TIME ZONE 'UTC') = mw.recurrence_day
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time >= (mw.start_at AT TIME ZONE 'UTC')::time
|
||||
AND (NOW() AT TIME ZONE 'UTC')::time < ((mw.start_at AT TIME ZONE 'UTC')::time
|
||||
+ (mw.duration_minutes * INTERVAL '1 minute'))
|
||||
)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(hosts) => hosts,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "dispatch_open_window_jobs: open-hosts query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if open_hosts.is_empty() {
|
||||
tracing::debug!("Maintenance scheduler: no open windows this cycle");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
open_host_count = open_hosts.len(),
|
||||
"Maintenance scheduler: found hosts with open windows"
|
||||
);
|
||||
|
||||
// ── 2. For each open host, find distinct queued non-immediate job IDs ──
|
||||
for host in open_hosts {
|
||||
let job_ids: Vec<QueuedJobId> = match sqlx::query_as(
|
||||
r#"
|
||||
SELECT DISTINCT pjh.job_id
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN patch_jobs j ON j.id = pjh.job_id
|
||||
WHERE pjh.host_id = $1
|
||||
AND pjh.status = 'queued'
|
||||
AND j.immediate = FALSE
|
||||
AND j.status != 'cancelled'
|
||||
AND (pjh.retry_next_at IS NULL OR pjh.retry_next_at <= NOW())
|
||||
"#,
|
||||
)
|
||||
.bind(host.host_id)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
Ok(ids) => ids,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
host_id = %host.host_id,
|
||||
"dispatch_open_window_jobs: queued jobs query failed"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for job in job_ids {
|
||||
tracing::info!(
|
||||
job_id = %job.job_id,
|
||||
host_id = %host.host_id,
|
||||
"Maintenance scheduler: dispatching non-immediate job (window open)"
|
||||
);
|
||||
|
||||
let (p, c) = (pool.clone(), config.clone());
|
||||
let job_id = job.job_id;
|
||||
tokio::spawn(async move {
|
||||
process_job(p, c, job_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
470
crates/pm-worker/src/ws_relay.rs
Normal file
470
crates/pm-worker/src/ws_relay.rs
Normal file
@ -0,0 +1,470 @@
|
||||
//! WS relay — M7
|
||||
//!
|
||||
//! For every running `patch_job_hosts` row that has an `agent_job_id`, open a
|
||||
//! WebSocket to the corresponding agent, stream job-status events, update the
|
||||
//! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS
|
||||
//! handler can forward the event to connected clients.
|
||||
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::StreamExt;
|
||||
use rustls::{
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
ClientConfig as TlsClientConfig,
|
||||
RootCertStore,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_tungstenite::{
|
||||
connect_async_tls_with_config,
|
||||
tungstenite::protocol::Message,
|
||||
Connector,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use pm_agent_client::client::DEFAULT_AGENT_PORT;
|
||||
use pm_core::config::AppConfig;
|
||||
|
||||
// ── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct RunningHostJob {
|
||||
job_id: Uuid,
|
||||
host_id: Uuid,
|
||||
agent_job_id: String,
|
||||
host_address: String,
|
||||
}
|
||||
|
||||
/// JSON event streamed by the agent over its WS endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AgentWsEvent {
|
||||
#[allow(dead_code)]
|
||||
job_id: String,
|
||||
status: String,
|
||||
output: Option<String>,
|
||||
error: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
progress_percent: Option<u8>,
|
||||
}
|
||||
|
||||
/// Payload broadcast via `pg_notify('job_update', …)`.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NotifyPayload {
|
||||
job_id: String,
|
||||
host_id: String,
|
||||
status: String,
|
||||
output: Option<String>,
|
||||
error_message: Option<String>,
|
||||
agent_job_id: String,
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Long-running task: polls the DB for running host-jobs and spawns a per-pair
|
||||
/// relay task for each one that isn't already being tracked.
|
||||
pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
|
||||
tracing::info!("WS relay task started");
|
||||
|
||||
let active: Arc<Mutex<HashSet<(Uuid, Uuid)>>> = Arc::new(Mutex::new(HashSet::new()));
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(10));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let rows = match query_running_jobs(&pool).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: DB poll failed");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for row in rows {
|
||||
let key = (row.job_id, row.host_id);
|
||||
|
||||
// Skip pairs that already have an active relay.
|
||||
if active.lock().await.contains(&key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build the rustls ClientConfig once per connection.
|
||||
let tls_config = match build_tls_config(&config).await {
|
||||
Ok(c) => Arc::new(c),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "ws_relay: TLS config error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
active.lock().await.insert(key);
|
||||
|
||||
let pool_c = pool.clone();
|
||||
let active_c = active.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
agent_job_id = %row.agent_job_id,
|
||||
host = %row.host_address,
|
||||
"WS relay: starting relay"
|
||||
);
|
||||
|
||||
match relay_one_job(&pool_c, &row, tls_config).await {
|
||||
Ok(()) => tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: completed"
|
||||
),
|
||||
Err(e) => tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: ended with error"
|
||||
),
|
||||
}
|
||||
|
||||
active_c.lock().await.remove(&key);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── DB helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn query_running_jobs(pool: &PgPool) -> anyhow::Result<Vec<RunningHostJob>> {
|
||||
sqlx::query_as::<_, RunningHostJob>(
|
||||
r#"
|
||||
SELECT
|
||||
pjh.job_id,
|
||||
pjh.host_id,
|
||||
pjh.agent_job_id,
|
||||
COALESCE(h.fqdn, h.ip_address::text) AS host_address
|
||||
FROM patch_job_hosts pjh
|
||||
JOIN hosts h ON h.id = pjh.host_id
|
||||
WHERE pjh.status = 'running'::job_status
|
||||
AND pjh.agent_job_id IS NOT NULL
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.context("query_running_jobs")
|
||||
}
|
||||
|
||||
// ── TLS ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn build_tls_config(config: &AppConfig) -> anyhow::Result<TlsClientConfig> {
|
||||
let sec = &config.security;
|
||||
|
||||
let cert_pem = tokio::fs::read(&sec.agent_client_cert_path).await
|
||||
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
|
||||
let key_pem = tokio::fs::read(&sec.agent_client_key_path).await
|
||||
.with_context(|| format!("read agent client key '{}'" , sec.agent_client_key_path))?;
|
||||
let ca_pem = tokio::fs::read(&sec.ca_cert_path).await
|
||||
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
|
||||
|
||||
// Parse client certificate chain.
|
||||
let client_certs: Vec<CertificateDer<'static>> = {
|
||||
let mut cur = std::io::Cursor::new(&cert_pem);
|
||||
rustls_pemfile::certs(&mut cur)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.context("parse client cert PEM")?
|
||||
};
|
||||
|
||||
// Parse client private key.
|
||||
let client_key: PrivateKeyDer<'static> = {
|
||||
let mut cur = std::io::Cursor::new(&key_pem);
|
||||
rustls_pemfile::private_key(&mut cur)
|
||||
.context("parse client key PEM")?
|
||||
.context("no private key in PEM")?
|
||||
};
|
||||
|
||||
// Build root store from CA cert.
|
||||
let mut root_store = RootCertStore::empty();
|
||||
{
|
||||
let mut cur = std::io::Cursor::new(&ca_pem);
|
||||
for cert_result in rustls_pemfile::certs(&mut cur) {
|
||||
root_store
|
||||
.add(cert_result.context("read CA cert entry")?)
|
||||
.context("add CA cert to root store")?;
|
||||
}
|
||||
}
|
||||
|
||||
TlsClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_client_auth_cert(client_certs, client_key)
|
||||
.context("build TlsClientConfig")
|
||||
}
|
||||
|
||||
// ── Per-job relay ─────────────────────────────────────────────────────────────
|
||||
|
||||
async fn relay_one_job(
|
||||
pool: &PgPool,
|
||||
row: &RunningHostJob,
|
||||
tls_config: Arc<TlsClientConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let url = format!(
|
||||
"wss://{}:{}/api/v1/ws/jobs",
|
||||
row.host_address, DEFAULT_AGENT_PORT,
|
||||
);
|
||||
|
||||
let (ws_stream, _) = connect_async_tls_with_config(
|
||||
url.as_str(),
|
||||
None,
|
||||
false,
|
||||
Some(Connector::Rustls(tls_config)),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("connect agent WS {url}"))?;
|
||||
|
||||
let (_sink, mut stream) = ws_stream.split();
|
||||
|
||||
while let Some(frame) = stream.next().await {
|
||||
let frame = match frame {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: stream error"
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match frame {
|
||||
Message::Text(t) => t.to_string(),
|
||||
Message::Binary(b) => String::from_utf8(b.into()).unwrap_or_default(),
|
||||
Message::Close(_) => {
|
||||
tracing::info!(job_id = %row.job_id, "Agent WS closed cleanly");
|
||||
break;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event: AgentWsEvent = match serde_json::from_str(&text) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e, raw = %text,
|
||||
"WS relay: unparseable agent frame"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
process_event(pool, row, &event).await;
|
||||
|
||||
if matches!(event.status.as_str(), "succeeded" | "failed" | "cancelled") {
|
||||
tracing::info!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %event.status,
|
||||
"WS relay: terminal state — stopping"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Event processing ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) {
|
||||
// Map agent status string to DB job_status enum value.
|
||||
let db_status = match event.status.as_str() {
|
||||
"running" => "running",
|
||||
"succeeded" => "succeeded",
|
||||
"failed" => "failed",
|
||||
"cancelled" => "cancelled",
|
||||
other => {
|
||||
tracing::warn!(status = %other, "WS relay: unknown agent status");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let output = event.output.as_deref().unwrap_or("");
|
||||
let error_msg = event.error.as_deref();
|
||||
|
||||
// Determine timestamps based on terminal state.
|
||||
let is_terminal = matches!(db_status, "succeeded" | "failed" | "cancelled");
|
||||
|
||||
// Update the DB row.
|
||||
let update_result = if is_terminal {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = $1::job_status,
|
||||
output = CASE WHEN $2 != '' THEN $2 ELSE output END,
|
||||
error_message = $3,
|
||||
completed_at = NOW()
|
||||
WHERE job_id = $4
|
||||
AND host_id = $5
|
||||
"#,
|
||||
)
|
||||
.bind(db_status)
|
||||
.bind(output)
|
||||
.bind(error_msg)
|
||||
.bind(row.job_id)
|
||||
.bind(row.host_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE patch_job_hosts
|
||||
SET status = $1::job_status,
|
||||
output = CASE WHEN $2 != '' THEN $2 ELSE output END
|
||||
WHERE job_id = $3
|
||||
AND host_id = $4
|
||||
"#,
|
||||
)
|
||||
.bind(db_status)
|
||||
.bind(output)
|
||||
.bind(row.job_id)
|
||||
.bind(row.host_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
};
|
||||
|
||||
if let Err(e) = update_result {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: DB update failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Also update the parent patch_jobs status when the host-level job reaches
|
||||
// a terminal state: running → if all hosts terminal then update parent.
|
||||
if is_terminal {
|
||||
update_parent_job_status(pool, row.job_id).await;
|
||||
}
|
||||
|
||||
// Fire pg_notify so browser WS handlers forward the event.
|
||||
let payload = NotifyPayload {
|
||||
job_id: row.job_id.to_string(),
|
||||
host_id: row.host_id.to_string(),
|
||||
status: db_status.to_string(),
|
||||
output: event.output.clone(),
|
||||
error_message: event.error.clone(),
|
||||
agent_job_id: row.agent_job_id.clone(),
|
||||
};
|
||||
|
||||
let payload_json = match serde_json::to_string(&payload) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "WS relay: failed to serialize notify payload");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query("SELECT pg_notify('job_update', $1)")
|
||||
.bind(&payload_json)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
"WS relay: pg_notify failed"
|
||||
);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
job_id = %row.job_id,
|
||||
host_id = %row.host_id,
|
||||
status = %db_status,
|
||||
"WS relay: pg_notify sent"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Parent job status rollup ──────────────────────────────────────────────────
|
||||
|
||||
/// After a host-level job reaches a terminal state, check whether ALL hosts for
|
||||
/// that job are now terminal and update the parent `patch_jobs` row accordingly.
|
||||
async fn update_parent_job_status(pool: &PgPool, job_id: Uuid) {
|
||||
// Count hosts that are still in a non-terminal state.
|
||||
let pending: i64 = match sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT COUNT(*)
|
||||
FROM patch_job_hosts
|
||||
WHERE job_id = $1
|
||||
AND status NOT IN (
|
||||
'succeeded'::job_status,
|
||||
'failed'::job_status,
|
||||
'cancelled'::job_status
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: count query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if pending > 0 {
|
||||
return; // still hosts running — parent stays running
|
||||
}
|
||||
|
||||
// All hosts terminal — determine final parent status.
|
||||
let failed_count: i64 = match sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM patch_job_hosts WHERE job_id = $1 AND status = 'failed'::job_status",
|
||||
)
|
||||
.bind(job_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, %job_id, "update_parent_job_status: failed-count query failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let final_status = if failed_count > 0 { "failed" } else { "succeeded" };
|
||||
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE patch_jobs SET status = $1::job_status, completed_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(final_status)
|
||||
.bind(job_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"update_parent_job_status: UPDATE failed"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
%job_id,
|
||||
status = %final_status,
|
||||
"Parent job status updated"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user