Private
Public Access
1
0

fix: add ALPN http/1.1 for WebSocket, polling fallback, and job-level WS events
Some checks failed
CI Pipeline / Rust Format Check (push) Failing after 19s
CI Pipeline / Clippy Lints (push) Successful in 46s
CI Pipeline / Rust Unit Tests (push) Successful in 1m30s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 1m11s
CI Pipeline / Build .deb & Release (push) Has been skipped

- ws_relay.rs: Add ALPN protocol http/1.1 to rustls ClientConfig to prevent
  HTTP/2 negotiation which breaks WebSocket upgrades (Sec-WebSocket-Accept mismatch)
- ws_relay.rs: Add detailed TLS error chain logging for debugging connection failures
- ws_relay.rs: Add HTTP polling fallback when WebSocket connection fails, using
  AgentClient to poll /api/v1/jobs/{id} every ws_relay_poll_interval_secs
- config.rs: Add ws_relay_poll_interval_secs field (default: 10 seconds)
- config.example.toml: Add ws_relay_poll_interval_secs documentation
- jobs.rs: Fire pg_notify with event_type job on cancel
- job_executor.rs: Fire pg_notify with event_type job when parent job transitions
- ws_relay.rs: Add event_type field to NotifyPayload (host vs job events)
- Frontend: Add event_type, succeeded_count, failed_count, host_count to JobWsEvent
- Frontend: handleWsEvent distinguishes host vs job events for accurate status updates
This commit is contained in:
2026-05-04 15:16:20 +00:00
parent 177a608b97
commit e3a27eb2ed
4 changed files with 246 additions and 293 deletions

View File

@ -5,7 +5,7 @@
//! DB row, and fire `pg_notify('job_update', payload_json)` so the browser WS
//! handler can forward the event to connected clients.
use std::{collections::HashSet, sync::Arc, time::Duration};
use std::{collections::HashSet, error::Error, sync::Arc, time::Duration};
use anyhow::Context;
use futures::StreamExt;
@ -21,6 +21,7 @@ use uuid::Uuid;
use pm_agent_client::client::DEFAULT_AGENT_PORT;
use pm_core::config::AppConfig;
use pm_agent_client::client::AgentClient;
// ── Types ─────────────────────────────────────────────────────────────────────
@ -60,6 +61,38 @@ struct NotifyPayload {
host_count: Option<i64>,
}
// ── Cert PEM bytes for building AgentClient ────────────────────────────────────
/// Raw PEM bytes read from the security config cert paths.
/// Used to build an [`AgentClient`] for HTTP polling fallback.
#[derive(Clone)]
struct CertPems {
client_cert: Vec<u8>,
client_key: Vec<u8>,
ca_cert: Vec<u8>,
}
/// Read the three PEM files referenced by the security config.
/// Mirrors the file reads in [`build_tls_config`] but returns raw bytes instead of
/// parsing them into rustls types.
async fn read_cert_pems(config: &AppConfig) -> anyhow::Result<CertPems> {
let sec = &config.security;
let client_cert = tokio::fs::read(&sec.agent_client_cert_path)
.await
.with_context(|| format!("read agent client cert '{}'", sec.agent_client_cert_path))?;
let client_key = tokio::fs::read(&sec.agent_client_key_path)
.await
.with_context(|| format!("read agent client key '{}'", sec.agent_client_key_path))?;
let ca_cert = tokio::fs::read(&sec.ca_cert_path)
.await
.with_context(|| format!("read CA cert '{}'", sec.ca_cert_path))?;
Ok(CertPems {
client_cert,
client_key,
ca_cert,
})
}
// ── Entry point ───────────────────────────────────────────────────────────────
/// Long-running task: polls the DB for running host-jobs and spawns a per-pair
@ -100,6 +133,17 @@ pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
},
};
// Read raw cert PEM bytes for HTTP polling fallback.
let cert_pems = match read_cert_pems(&config).await {
Ok(p) => p,
Err(e) => {
tracing::error!(error = %e, "ws_relay: cert PEM read error");
continue;
},
};
let poll_interval = config.worker.ws_relay_poll_interval_secs;
active.lock().await.insert(key);
let pool_c = pool.clone();
@ -114,7 +158,7 @@ pub async fn run_ws_relay(pool: PgPool, config: Arc<AppConfig>) {
"WS relay: starting relay"
);
match relay_one_job(&pool_c, &row, tls_config).await {
match relay_one_job(&pool_c, &row, tls_config, &cert_pems, poll_interval).await {
Ok(()) => tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
@ -197,32 +241,68 @@ async fn build_tls_config(config: &AppConfig) -> anyhow::Result<TlsClientConfig>
}
}
TlsClientConfig::builder()
let mut config = TlsClientConfig::builder()
.with_root_certificates(root_store)
.with_client_auth_cert(client_certs, client_key)
.context("build TlsClientConfig")
.context("build TlsClientConfig")?;
// WebSocket requires HTTP/1.1 — without ALPN the server may negotiate
// h2 (HTTP/2), which breaks the WebSocket upgrade handshake
// ("Key mismatch in Sec-WebSocket-Accept header").
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(config)
}
// ── Per-job relay ─────────────────────────────────────────────────────────────
async fn relay_one_job(
pool: &PgPool,
row: &RunningHostJob,
tls_config: Arc<TlsClientConfig>,
cert_pems: &CertPems,
poll_interval_secs: u64,
) -> anyhow::Result<()> {
let url = format!(
"wss://{}:{}/api/v1/ws/jobs",
row.host_address, DEFAULT_AGENT_PORT,
);
let (ws_stream, _) = connect_async_tls_with_config(
let (ws_stream, _) = match connect_async_tls_with_config(
url.as_str(),
None,
false,
Some(Connector::Rustls(tls_config)),
)
.await
.with_context(|| format!("connect agent WS {url}"))?;
{
Ok(ws) => ws,
Err(e) => {
// Log the full error chain for TLS debugging
let mut source = e.source();
let mut depth = 0;
while let Some(err) = source {
tracing::warn!(
job_id = %row.job_id,
host_id = %row.host_id,
depth,
error = %err,
"WS relay: TLS connection error detail"
);
source = err.source();
depth += 1;
}
// Fall back to HTTP polling instead of returning an error.
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
host = %row.host_address,
"WS relay: WebSocket connection failed, falling back to HTTP polling"
);
return relay_one_job_poll(pool, row, cert_pems, poll_interval_secs).await;
},
};
let (_sink, mut stream) = ws_stream.split();
@ -281,6 +361,110 @@ async fn relay_one_job(
Ok(())
}
// ── Per-job HTTP polling fallback ─────────────────────────────────────────────
/// Fall back to HTTP polling when the WebSocket connection fails.
///
/// Builds an [`AgentClient`] from the same cert paths used for TLS, polls
/// `GET /api/v1/jobs/{id}` every `poll_interval_secs`, and calls
/// [`process_event`] whenever the status changes from the previous poll.
/// Stops when the job reaches a terminal state (succeeded/failed/cancelled).
async fn relay_one_job_poll(
pool: &PgPool,
row: &RunningHostJob,
cert_pems: &CertPems,
poll_interval_secs: u64,
) -> anyhow::Result<()> {
// Build the HTTP client using the same mTLS certs.
let agent_client = AgentClient::new(
&row.host_address,
DEFAULT_AGENT_PORT,
&cert_pems.client_cert,
&cert_pems.client_key,
&cert_pems.ca_cert,
)
.context("build AgentClient for polling fallback")?;
let mut last_status: Option<String> = None;
let poll_interval = Duration::from_secs(poll_interval_secs);
loop {
tokio::time::sleep(poll_interval).await;
let job_status = match agent_client.job_status(&row.agent_job_id).await {
Ok(s) => s,
Err(e) => {
tracing::debug!(
job_id = %row.job_id,
host_id = %row.host_id,
error = %e,
"WS relay poll: job_status request failed, will retry"
);
continue;
},
};
// Map agent status to the WS event format.
// The agent uses "completed" but pm-worker expects "succeeded".
// The agent uses "queued" but pm-worker expects "running".
let mapped_status = match job_status.status.as_str() {
"queued" => "running",
"running" => "running",
"succeeded" => "succeeded",
"completed" => "succeeded",
"failed" => "failed",
"cancelled" => "cancelled",
other => {
tracing::warn!(
status = %other,
job_id = %row.job_id,
"WS relay poll: unknown agent status, treating as running"
);
"running"
},
};
tracing::debug!(
job_id = %row.job_id,
host_id = %row.host_id,
agent_status = %job_status.status,
mapped_status = %mapped_status,
progress = ?job_status.progress_percent,
"WS relay poll: fetched job status"
);
// Only process when the status has changed since the last poll.
if last_status.as_deref() == Some(mapped_status) {
continue;
}
last_status = Some(mapped_status.to_string());
let event = AgentWsEvent {
job_id: job_status.job_id.clone(),
status: mapped_status.to_string(),
output: job_status.output.clone(),
error: job_status.error.clone(),
progress_percent: job_status.progress_percent,
};
process_event(pool, row, &event).await;
// Stop polling on terminal states.
if matches!(mapped_status, "succeeded" | "failed" | "cancelled") {
tracing::info!(
job_id = %row.job_id,
host_id = %row.host_id,
status = %mapped_status,
"WS relay poll: terminal state — stopping"
);
break;
}
}
Ok(())
}
// ── Event processing ──────────────────────────────────────────────────────────
async fn process_event(pool: &PgPool, row: &RunningHostJob, event: &AgentWsEvent) {