Private
Public Access
1
0

fix(ws): add Origin allowlist to browser WebSocket upgrade (CSWSH hardening)

Closes Draco-Lunaris/Linux-Patch-Manager#10

The browser WebSocket endpoint at GET /api/v1/ws/jobs previously
authenticated solely via a single-use, 60-second ticket passed as a query
parameter. A leaked ticket (browser history, Referer, proxy logs, support
bundles) could be redeemed from any origin, enabling Cross-Site WebSocket
Hijacking (CSWSH).

This change adds a second gate: the Origin header must match an explicit
allowlist. The check runs BEFORE ticket validation so that rejected
cross-origin probes do not consume the legitimate users ticket.

Changes:
- pm-core: new security.allowed_origins config field; default derived
  from sso_callback_url; startup warning if both are unparseable
- pm-web: ws_handler extracts HeaderMap and calls check_origin first;
  returns 403 on missing/malformed/disallowed origins
- config: documented allowed_origins key in config.example.toml
- docs: security-review.md section 1.4 (WebSocket Origin Allowlist)
- tests: 40 unit tests (7 pm-core, 33 pm-web)
This commit is contained in:
Draco Lunaris
2026-06-02 10:45:38 -05:00
parent 80709d48a7
commit ed5df26140
8 changed files with 925 additions and 12 deletions

View File

@ -140,6 +140,73 @@ pub struct SecurityConfig {
/// Frontend URL to redirect to after SSO callback (default: http://localhost:5173/auth/sso/callback)
#[serde(default = "default_sso_callback_url")]
pub sso_callback_url: String,
/// Allowlist of browser `Origin` values permitted to open the
/// `/api/v1/ws/jobs` WebSocket upgrade. Entries are exact
/// `scheme://host[:port]` strings. If left empty in the TOML file, the
/// server derives the default from `sso_callback_url` at load time
/// (see [`derive_allowed_origins`]).
#[serde(default)]
pub allowed_origins: Vec<String>,
}
/// Derive a default `Origin` allowlist from a single SSO callback URL.
///
/// Parses `scheme://host[:port][/path]` and returns a single-element vector
/// containing `scheme://host[:port]` (with default ports normalized away —
/// e.g. `https://x:443` becomes `https://x`). Returns an empty vector if the
/// URL is unparseable; callers should log a warning in that case because the
/// WebSocket endpoint will reject all browser upgrades (fail-closed).
///
/// Exposed publicly so tests and the handler can share the same parser.
pub fn derive_allowed_origins(sso_callback_url: &str) -> Vec<String> {
let s = sso_callback_url.trim().trim_end_matches('/');
let (scheme, rest) = match s.split_once("://") {
Some(parts) if !parts.0.is_empty() => parts,
_ => return vec![],
};
let scheme_lower = scheme.to_ascii_lowercase();
if scheme_lower != "http" && scheme_lower != "https" {
return vec![];
}
// Authority is everything up to the first `/`, `?`, or `#`.
let authority_end = rest
.find(['/', '?', '#'])
.unwrap_or(rest.len());
let authority = &rest[..authority_end];
if authority.is_empty() {
return vec![];
}
// Split host:port. We treat the LAST `:` as the port separator. IPv6
// literal hosts (e.g. `[::1]`) contain a `:` inside the brackets; we
// explicitly do not support IPv6 in sso_callback_url and return empty
// for those to be safe.
let (host, port_str) = match authority.rsplit_once(':') {
Some((h, _)) if h.contains(':') => return vec![],
Some((h, p)) => (h, Some(p)),
None => (authority, None),
};
let host = host.trim();
if host.is_empty() || host.contains(char::is_whitespace) || host.contains(':') {
return vec![];
}
let default_port: Option<u16> = match scheme_lower.as_str() {
"https" => Some(443),
"http" => Some(80),
_ => None,
};
let port_num = match port_str {
Some(p) => match p.parse::<u16>() {
Ok(n) => Some(n),
Err(_) => return vec![],
},
None => None,
};
let origin = match (port_num, default_port) {
(Some(p), Some(d)) if p == d => format!("{}://{}", scheme_lower, host),
(Some(p), _) => format!("{}://{}:{}", scheme_lower, host, p),
(None, _) => format!("{}://{}", scheme_lower, host),
};
vec![origin]
}
impl AppConfig {
@ -147,6 +214,11 @@ impl AppConfig {
///
/// Environment variables follow the pattern: `PATCH_MANAGER__SECTION__KEY`
/// e.g. `PATCH_MANAGER__DATABASE__URL=postgres://...`
///
/// After deserialization, if `security.allowed_origins` is empty, it is
/// derived from `security.sso_callback_url`. A `tracing::warn!` is emitted
/// when the resulting allowlist is empty (the WS endpoint will reject all
/// browser upgrades in that case).
pub fn load(config_path: &str) -> Result<Self, ConfigError> {
let cfg = Config::builder()
.add_source(File::with_name(config_path).required(false))
@ -157,7 +229,20 @@ impl AppConfig {
)
.build()?;
cfg.try_deserialize()
let mut config: Self = cfg.try_deserialize()?;
if config.security.allowed_origins.is_empty() {
config.security.allowed_origins =
derive_allowed_origins(&config.security.sso_callback_url);
}
if config.security.allowed_origins.is_empty() {
tracing::warn!(
sso_callback_url = %config.security.sso_callback_url,
"security.allowed_origins is empty and could not be derived \
from sso_callback_url; the WebSocket endpoint will reject all \
browser upgrades"
);
}
Ok(config)
}
}
@ -207,8 +292,69 @@ impl Default for AppConfig {
web_tls_cert_path: "/etc/patch-manager/tls/web.crt".to_string(),
web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(),
sso_callback_url: default_sso_callback_url(),
allowed_origins: derive_allowed_origins(&default_sso_callback_url()),
},
rate_limit: RateLimitConfig::default(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn derive_strips_default_https_port() {
assert_eq!(
derive_allowed_origins("https://app.example.com:443/auth/sso/callback"),
vec!["https://app.example.com".to_string()]
);
}
#[test]
fn derive_keeps_non_default_port() {
assert_eq!(
derive_allowed_origins("https://app.example.com:8443/auth/sso/callback"),
vec!["https://app.example.com:8443".to_string()]
);
}
#[test]
fn derive_strips_default_http_port() {
assert_eq!(
derive_allowed_origins("http://localhost:80/x"),
vec!["http://localhost".to_string()]
);
}
#[test]
fn derive_handles_trailing_slash() {
assert_eq!(
derive_allowed_origins("https://app.example.com/"),
vec!["https://app.example.com".to_string()]
);
}
#[test]
fn derive_handles_no_path() {
assert_eq!(
derive_allowed_origins("https://app.example.com"),
vec!["https://app.example.com".to_string()]
);
}
#[test]
fn derive_returns_empty_for_garbage() {
assert!(derive_allowed_origins("not a url").is_empty());
assert!(derive_allowed_origins("").is_empty());
assert!(derive_allowed_origins("ftp://x").is_empty());
}
#[test]
fn derive_lowercases_scheme() {
assert_eq!(
derive_allowed_origins("HTTPS://App.Example.com"),
vec!["https://App.Example.com".to_string()]
);
}
}

457
crates/pm-web/src/routes/ws.rs Executable file → Normal file
View File

@ -6,7 +6,7 @@
use axum::{
extract::ws::{Message, WebSocket},
extract::{Query, State, WebSocketUpgrade},
http::StatusCode,
http::{HeaderMap, StatusCode},
response::{Json, Response},
routing::{get, post},
Router,
@ -57,6 +57,162 @@ fn err(
)
}
// ── Origin parsing & allowlist matching ───────────────────────────────────────
/// Parsed browser `Origin` header value.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Origin {
scheme: String,
host: String,
/// `None` means "use scheme default" (80 for http, 443 for https).
port: Option<u16>,
}
impl Origin {
/// Render back to canonical `scheme://host[:port]` form with default
/// ports normalized away (so `https://x:443` becomes `https://x`).
fn canonical(&self) -> String {
let default_port: Option<u16> = match self.scheme.as_str() {
"https" => Some(443),
"http" => Some(80),
_ => None,
};
match (self.port, default_port) {
(Some(p), Some(d)) if p == d => format!("{}://{}", self.scheme, self.host),
(Some(p), _) => format!("{}://{}:{}", self.scheme, self.host, p),
(None, _) => format!("{}://{}", self.scheme, self.host),
}
}
}
/// Parse a raw `Origin` header value. Returns `None` for missing scheme,
/// unsupported schemes (only `http`/`https`), empty host, or whitespace in
/// the host. IPv6 literal hosts are explicitly rejected to keep the parser
/// simple — WebSocket connections from IPv6 browser origins are not a
/// realistic deployment for this product.
fn parse_origin_header(value: &str) -> Option<Origin> {
let s = value.trim().trim_end_matches('/');
if s.is_empty() {
return None;
}
let (scheme, rest) = s.split_once("://")?;
let scheme = scheme.to_ascii_lowercase();
if scheme != "http" && scheme != "https" {
return None;
}
// Authority is everything up to the first `/`, `?`, or `#`.
let authority_end = rest
.find(['/', '?', '#'])
.unwrap_or(rest.len());
let authority = &rest[..authority_end];
if authority.is_empty() {
return None;
}
// Treat the LAST `:` as the port separator. IPv6 literal hosts (e.g.
// `[::1]`) contain a `:` inside the brackets; reject those.
let (host, port_str) = match authority.rsplit_once(':') {
Some((h, _)) if h.contains(':') => return None,
Some((h, p)) => (h, Some(p)),
None => (authority, None),
};
let host = host.trim();
if host.is_empty() || host.contains(char::is_whitespace) || host.contains(':') {
return None;
}
let port = match port_str {
Some(p) => match p.parse::<u16>() {
Ok(n) => Some(n),
Err(_) => return None,
},
None => None,
};
Some(Origin {
scheme,
host: host.to_ascii_lowercase(),
port,
})
}
/// Match a parsed `Origin` against an allowlist. Each allowlist entry is
/// itself parsed with [`parse_origin_header`] and compared by its canonical
/// string form, so entry syntax is forgiving (`https://x:443` matches an
/// incoming `https://x`). The host comparison is case-insensitive (the
/// parser lowercases the host); scheme and port are exact.
///
/// An empty allowlist returns `false` (fail-closed).
fn is_origin_allowed(origin: &Origin, allowlist: &[String]) -> bool {
if allowlist.is_empty() {
return false;
}
let incoming = origin.canonical();
allowlist.iter().any(|entry| {
match parse_origin_header(entry) {
Some(parsed) => parsed.canonical() == incoming,
None => false,
}
})
}
/// Read the `Origin` header from a request and check it against the
/// configured allowlist. Returns `Ok(())` when the request may proceed; on
/// rejection returns the appropriate `(StatusCode, Json)` error tuple and
/// the reason string (for logging).
fn check_origin(
headers: &HeaderMap,
allowlist: &[String],
) -> Result<(), ((StatusCode, Json<Value>), &'static str)> {
let raw = match headers.get(axum::http::header::ORIGIN) {
Some(v) => v,
None => {
return Err((
err(
StatusCode::FORBIDDEN,
"forbidden_origin",
"Origin header required",
),
"missing",
));
}
};
let raw_str = match raw.to_str() {
Ok(s) => s,
Err(_) => {
return Err((
err(
StatusCode::FORBIDDEN,
"forbidden_origin",
"Origin header not valid ASCII",
),
"non-ascii",
));
}
};
let origin = match parse_origin_header(raw_str) {
Some(o) => o,
None => {
return Err((
err(
StatusCode::FORBIDDEN,
"forbidden_origin",
"Malformed Origin header",
),
"malformed",
));
}
};
if !is_origin_allowed(&origin, allowlist) {
return Err((
err(
StatusCode::FORBIDDEN,
"forbidden_origin",
"Origin not allowed",
),
"not-allowlisted",
));
}
Ok(())
}
// ── POST /api/v1/ws/ticket ────────────────────────────────────────────────────
/// Issue a single-use WebSocket authentication ticket (60 s expiry).
@ -93,11 +249,40 @@ pub struct WsQuery {
}
/// Browser WebSocket upgrade endpoint — authenticates via single-use ticket.
///
/// The handler enforces two independent gates, in this order:
///
/// 1. `Origin` header allowlist (CSWSH defense-in-depth). Performed first so
/// that a cross-origin probe with a leaked/stolen ticket does not consume
/// the legitimate user's ticket.
/// 2. Single-use, 60-second ticket (existing behavior, unchanged).
pub async fn ws_handler(
State(state): State<AppState>,
headers: HeaderMap,
Query(q): Query<WsQuery>,
ws: WebSocketUpgrade,
) -> Result<Response, (StatusCode, Json<Value>)> {
// Gate 1: Origin allowlist (CSWSH defense-in-depth).
let allowlist = &state.config.security.allowed_origins;
if let Err((http_err, reason)) = check_origin(&headers, allowlist) {
let raw_origin = headers
.get(axum::http::header::ORIGIN)
.and_then(|v| v.to_str().ok())
.unwrap_or("<absent>");
// Never log the ticket value.
tracing::warn!(
reason = reason,
origin = %raw_origin,
"WebSocket upgrade rejected: forbidden origin"
);
return Err(http_err);
}
let allowed_origin = headers
.get(axum::http::header::ORIGIN)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
// Validate and consume the ticket atomically.
let ticket = {
let entry = state.ws_tickets.get(&q.ticket);
@ -129,6 +314,7 @@ pub async fn ws_handler(
tracing::info!(
user_id = %ticket.user_id,
role = %ticket.role,
origin = %allowed_origin,
"Browser WebSocket connection upgraded"
);
@ -203,3 +389,272 @@ async fn handle_browser_ws(mut socket: WebSocket, db: sqlx::PgPool, ticket: WsTi
tracing::info!(user_id = %ticket.user_id, "Browser WS handler exiting");
}
// ── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
// ── parse_origin_header ─────────────────────────────────────────────────
#[test]
fn parse_basic_https() {
assert_eq!(
parse_origin_header("https://app.example.com"),
Some(Origin {
scheme: "https".into(),
host: "app.example.com".into(),
port: None,
})
);
}
#[test]
fn parse_with_explicit_port() {
assert_eq!(
parse_origin_header("https://app.example.com:8443"),
Some(Origin {
scheme: "https".into(),
host: "app.example.com".into(),
port: Some(8443),
})
);
}
#[test]
fn parse_lowercases_scheme() {
assert_eq!(
parse_origin_header("HTTPS://App.Example.com").unwrap().scheme,
"https"
);
}
#[test]
fn parse_lowercases_host() {
assert_eq!(
parse_origin_header("https://App.Example.com").unwrap().host,
"app.example.com"
);
}
#[test]
fn parse_ignores_path_query_fragment() {
let o = parse_origin_header("https://app.example.com:443/some/path?q=1#frag").unwrap();
assert_eq!(o.host, "app.example.com");
assert_eq!(o.port, Some(443));
}
#[test]
fn parse_strips_trailing_slash() {
assert_eq!(
parse_origin_header("https://app.example.com/"),
Some(Origin {
scheme: "https".into(),
host: "app.example.com".into(),
port: None,
})
);
}
#[test]
fn parse_rejects_empty() {
assert!(parse_origin_header("").is_none());
assert!(parse_origin_header(" ").is_none());
}
#[test]
fn parse_rejects_unsupported_scheme() {
assert!(parse_origin_header("ftp://x").is_none());
assert!(parse_origin_header("file:///etc/passwd").is_none());
assert!(parse_origin_header("javascript:alert(1)").is_none());
}
#[test]
fn parse_rejects_empty_host() {
assert!(parse_origin_header("https://").is_none());
assert!(parse_origin_header("https:///path").is_none());
}
#[test]
fn parse_rejects_host_with_whitespace() {
assert!(parse_origin_header("https://bad host").is_none());
}
#[test]
fn parse_rejects_malformed_port() {
assert!(parse_origin_header("https://x:notaport").is_none());
assert!(parse_origin_header("https://x:99999").is_none());
}
#[test]
fn parse_rejects_ipv6_literal() {
assert!(parse_origin_header("https://[::1]").is_none());
}
#[test]
fn parse_rejects_no_scheme_separator() {
assert!(parse_origin_header("app.example.com").is_none());
}
// ── canonical ──────────────────────────────────────────────────────────
#[test]
fn canonical_strips_default_https_port() {
let o = Origin {
scheme: "https".into(),
host: "x".into(),
port: Some(443),
};
assert_eq!(o.canonical(), "https://x");
}
#[test]
fn canonical_strips_default_http_port() {
let o = Origin {
scheme: "http".into(),
host: "x".into(),
port: Some(80),
};
assert_eq!(o.canonical(), "http://x");
}
#[test]
fn canonical_keeps_non_default_port() {
let o = Origin {
scheme: "https".into(),
host: "x".into(),
port: Some(8443),
};
assert_eq!(o.canonical(), "https://x:8443");
}
// ── is_origin_allowed ──────────────────────────────────────────────────
#[test]
fn allowed_exact_match() {
let o = parse_origin_header("https://app.example.com").unwrap();
assert!(is_origin_allowed(&o, &["https://app.example.com".into()]));
}
#[test]
fn allowed_default_port_normalization_incoming() {
let o = parse_origin_header("https://app.example.com:443").unwrap();
assert!(is_origin_allowed(&o, &["https://app.example.com".into()]));
}
#[test]
fn allowed_default_port_normalization_allowlist() {
let o = parse_origin_header("https://app.example.com").unwrap();
assert!(is_origin_allowed(&o, &["https://app.example.com:443".into()]));
}
#[test]
fn allowed_case_insensitive_host() {
let o = parse_origin_header("https://App.Example.com").unwrap();
assert!(is_origin_allowed(&o, &["https://app.example.com".into()]));
}
#[test]
fn rejected_different_host() {
let o = parse_origin_header("https://evil.example").unwrap();
assert!(!is_origin_allowed(&o, &["https://app.example.com".into()]));
}
#[test]
fn rejected_different_scheme() {
let o = parse_origin_header("http://app.example.com").unwrap();
assert!(!is_origin_allowed(&o, &["https://app.example.com".into()]));
}
#[test]
fn rejected_different_port() {
let o = parse_origin_header("https://app.example.com:8443").unwrap();
assert!(!is_origin_allowed(&o, &["https://app.example.com".into()]));
}
#[test]
fn rejected_empty_allowlist() {
let o = parse_origin_header("https://app.example.com").unwrap();
assert!(!is_origin_allowed(&o, &[]));
}
#[test]
fn rejected_garbage_in_allowlist() {
let o = parse_origin_header("https://app.example.com").unwrap();
assert!(!is_origin_allowed(&o, &["not a url".into()]));
}
#[test]
fn allowed_multi_entry_allowlist() {
let o = parse_origin_header("https://app.example.com").unwrap();
assert!(is_origin_allowed(
&o,
&[
"https://other.example".into(),
"https://app.example.com".into(),
]
));
}
// ── check_origin (integration of parse + allow) ────────────────────────
#[test]
fn check_rejects_missing_header() {
let h = HeaderMap::new();
let err = check_origin(&h, &["https://app.example.com".into()]).unwrap_err();
assert_eq!(err.0 .0, StatusCode::FORBIDDEN);
assert_eq!(err.1, "missing");
}
#[test]
fn check_rejects_malformed_header() {
let mut h = HeaderMap::new();
h.insert(axum::http::header::ORIGIN, "not a url".parse().unwrap());
let err = check_origin(&h, &["https://app.example.com".into()]).unwrap_err();
assert_eq!(err.0 .0, StatusCode::FORBIDDEN);
assert_eq!(err.1, "malformed");
}
#[test]
fn check_rejects_disallowed_origin() {
let mut h = HeaderMap::new();
h.insert(axum::http::header::ORIGIN, "https://evil.example".parse().unwrap());
let err = check_origin(&h, &["https://app.example.com".into()]).unwrap_err();
assert_eq!(err.0 .0, StatusCode::FORBIDDEN);
assert_eq!(err.1, "not-allowlisted");
}
#[test]
fn check_rejects_empty_allowlist() {
let mut h = HeaderMap::new();
h.insert(axum::http::header::ORIGIN, "https://app.example.com".parse().unwrap());
let err = check_origin(&h, &[]).unwrap_err();
assert_eq!(err.0 .0, StatusCode::FORBIDDEN);
assert_eq!(err.1, "not-allowlisted");
}
#[test]
fn check_allows_valid_origin() {
let mut h = HeaderMap::new();
h.insert(axum::http::header::ORIGIN, "https://app.example.com".parse().unwrap());
assert!(check_origin(&h, &["https://app.example.com".into()]).is_ok());
}
#[test]
fn check_allows_default_port_normalization() {
let mut h = HeaderMap::new();
h.insert(
axum::http::header::ORIGIN,
"https://app.example.com:443".parse().unwrap(),
);
assert!(check_origin(&h, &["https://app.example.com".into()]).is_ok());
}
#[test]
fn check_allows_case_insensitive_host() {
let mut h = HeaderMap::new();
h.insert(axum::http::header::ORIGIN, "https://App.Example.com".parse().unwrap());
assert!(check_origin(&h, &["https://app.example.com".into()]).is_ok());
}
}