From 59794bc8f2bad534564f588b3dabc9c3ae651cf4 Mon Sep 17 00:00:00 2001 From: Echo Date: Thu, 21 May 2026 02:27:10 +0000 Subject: [PATCH] fix: replace broken DashMap rate limiting with tower-governor middleware - Replace custom DashMap rate limiting in enrollment.rs that fell back to 0.0.0.0 when X-Forwarded-For was missing, causing ALL enrollment traffic to share a single global rate limit bucket - Use tower_governor with SmartIpKeyExtractor for proper per-IP rate limiting that respects X-Forwarded-For headers (critical behind HAProxy) - Add three configurable rate limit tiers via config.toml: * Enrollment: 5 req/min per IP, burst 3 (strict) * Auth: 20 req/min per IP, burst 10 (moderate) * API: 120 req/min per IP, burst 30 (normal) - Remove enrollment_rate_limits from AppState and cleanup task - Remove manual rate limit code from enrollment.rs (headers param, IP extraction) - Add into_make_service_with_connect_info for ConnectInfo fallback - Add RateLimitConfig to AppConfig with sensible defaults Fixes: #1 --- Cargo.lock | 248 ++++++++++++++++++++++++- Cargo.toml | 4 + config/config.example.toml | 17 ++ crates/pm-core/src/config.rs | 58 ++++++ crates/pm-web/Cargo.toml | 2 + crates/pm-web/src/main.rs | 106 +++++++---- crates/pm-web/src/routes/enrollment.rs | 39 +--- 7 files changed, 395 insertions(+), 79 deletions(-) mode change 100755 => 100644 Cargo.toml mode change 100755 => 100644 crates/pm-core/src/config.rs mode change 100755 => 100644 crates/pm-web/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 6c1e352..b42fabc 100755 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,6 +687,19 @@ dependencies = [ "cipher", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -771,7 +784,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -885,7 +898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -970,6 +983,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "font-kit" version = "0.14.3" @@ -1046,6 +1065,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "forwarded-header-value" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" +dependencies = [ + "nonempty", + "thiserror 1.0.69", +] + [[package]] name = "freetype-sys" version = "0.20.1" @@ -1155,6 +1184,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +[[package]] +name = "futures-timer" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af43fadb8a98512d547e37b4e92e0ced13e205c061b87b4623eff01d918d6968" + [[package]] name = "futures-util" version = "0.3.32" @@ -1242,6 +1277,49 @@ dependencies = [ "weezl", ] +[[package]] +name = "governor" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if", + "dashmap 5.5.3", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.8.6", + "smallvec", + "spinning_top", +] + +[[package]] +name = "governor" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9efcab3c1958580ff1f25a2a41be1668f7603d849bb63af523b208a3cc1223b8" +dependencies = [ + "cfg-if", + "dashmap 6.1.0", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.4", + "hashbrown 0.16.1", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.9.4", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "h2" version = "0.4.13" @@ -1275,7 +1353,18 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", ] [[package]] @@ -1445,6 +1534,19 @@ dependencies = [ "webpki-roots 1.0.7", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -1994,6 +2096,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -2013,13 +2121,25 @@ dependencies = [ "memchr", ] +[[package]] +name = "nonempty" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" + +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -2307,6 +2427,26 @@ dependencies = [ "sha2", ] +[[package]] +name = "pin-project" +version = "1.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2466b2336ed02bcdca6b294417127b90ec92038d1d5c4fbeac971a922e0e0924" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c96395f0a926bc13b1c17622aaddda1ecb55d49c8f1bf9777e4d877800a43f8b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.17" @@ -2500,7 +2640,8 @@ dependencies = [ "axum-server", "base64", "chrono", - "dashmap", + "dashmap 6.1.0", + "governor 0.6.3", "hex", "ipnet", "jsonwebtoken", @@ -2520,6 +2661,7 @@ dependencies = [ "tokio", "tower", "tower-http", + "tower_governor", "tracing", "tracing-subscriber", "ulid", @@ -2600,6 +2742,12 @@ dependencies = [ "bstr", ] +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.5" @@ -2662,6 +2810,21 @@ version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2803,6 +2966,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.11.1", +] + [[package]] name = "rcgen" version = "0.13.2" @@ -2999,7 +3171,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -3307,7 +3479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -3319,6 +3491,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3612,7 +3793,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -3912,6 +4093,35 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "tonic" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac2a5518c70fa84342385732db33fb3f44bc4cc748936eb5833d2df34d6445ef" +dependencies = [ + "async-trait", + "axum", + "base64", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "socket2", + "sync_wrapper", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "totp-rs" version = "5.7.1" @@ -3936,9 +4146,12 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", + "indexmap", "pin-project-lite", + "slab", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -3985,6 +4198,23 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "tower_governor" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44de9b94d849d3c46e06a883d72d408c2de6403367b39df2b1c9d9e7b6736fe6" +dependencies = [ + "axum", + "forwarded-header-value", + "governor 0.10.4", + "http", + "pin-project", + "thiserror 2.0.18", + "tonic", + "tower", + "tracing", +] + [[package]] name = "tracing" version = "0.1.44" @@ -4477,7 +4707,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml old mode 100755 new mode 100644 index 29e7bc8..cb58bf7 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,5 +81,9 @@ aes-gcm = { version = "0.10" } ipnet = { version = "2" } url = { version = "2" } +# Rate limiting +tower_governor = { version = "0.8", features = ["tracing"] } +governor = "0.6" + # Email lettre = { version = "0.11.22", features = ["tokio1-rustls-transport"] } diff --git a/config/config.example.toml b/config/config.example.toml index a76ab8d..4d22c43 100644 --- a/config/config.example.toml +++ b/config/config.example.toml @@ -107,3 +107,20 @@ web_tls_key_path = "/etc/patch-manager/tls/web.key" # The backend sends tokens as query parameters to this URL. # Default: "http://localhost:5173/auth/sso/callback" (Vite dev server) sso_callback_url = "http://localhost:5173/auth/sso/callback" + +# ============================================================ +# Rate Limiting +# ============================================================ +[rate_limit] +# Enrollment endpoint: requests per minute per IP (default: 5) +enrollment_rpm = 5 +# Enrollment burst allowance (default: 3) +enrollment_burst = 3 +# Public auth endpoints: requests per minute per IP (default: 20) +auth_rpm = 20 +# Auth burst allowance (default: 10) +auth_burst = 10 +# Authenticated API: requests per minute per IP (default: 120) +api_rpm = 120 +# API burst allowance (default: 30) +api_burst = 30 diff --git a/crates/pm-core/src/config.rs b/crates/pm-core/src/config.rs old mode 100755 new mode 100644 index 83d58bb..628cafa --- a/crates/pm-core/src/config.rs +++ b/crates/pm-core/src/config.rs @@ -1,6 +1,61 @@ use config::{Config, ConfigError, Environment, File}; use serde::{Deserialize, Serialize}; +/// Rate limiting configuration per route group. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct RateLimitConfig { + /// Enrollment endpoint: requests per minute per IP (default: 5) + #[serde(default = "default_enrollment_rpm")] + pub enrollment_rpm: u32, + /// Enrollment burst allowance (default: 3) + #[serde(default = "default_enrollment_burst")] + pub enrollment_burst: u32, + /// Public auth endpoints: requests per minute per IP (default: 20) + #[serde(default = "default_auth_rpm")] + pub auth_rpm: u32, + /// Auth burst allowance (default: 10) + #[serde(default = "default_auth_burst")] + pub auth_burst: u32, + /// Authenticated API: requests per minute per IP (default: 120) + #[serde(default = "default_api_rpm")] + pub api_rpm: u32, + /// API burst allowance (default: 30) + #[serde(default = "default_api_burst")] + pub api_burst: u32, +} + +fn default_enrollment_rpm() -> u32 { + 5 +} +fn default_enrollment_burst() -> u32 { + 3 +} +fn default_auth_rpm() -> u32 { + 20 +} +fn default_auth_burst() -> u32 { + 10 +} +fn default_api_rpm() -> u32 { + 120 +} +fn default_api_burst() -> u32 { + 30 +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + enrollment_rpm: default_enrollment_rpm(), + enrollment_burst: default_enrollment_burst(), + auth_rpm: default_auth_rpm(), + auth_burst: default_auth_burst(), + api_rpm: default_api_rpm(), + api_burst: default_api_burst(), + } + } +} + /// Top-level application configuration. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AppConfig { @@ -9,6 +64,8 @@ pub struct AppConfig { pub worker: WorkerConfig, pub logging: LoggingConfig, pub security: SecurityConfig, + #[serde(default)] + pub rate_limit: RateLimitConfig, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -151,6 +208,7 @@ impl Default for AppConfig { web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(), sso_callback_url: default_sso_callback_url(), }, + rate_limit: RateLimitConfig::default(), } } } diff --git a/crates/pm-web/Cargo.toml b/crates/pm-web/Cargo.toml index c5bbd3a..71fde4b 100644 --- a/crates/pm-web/Cargo.toml +++ b/crates/pm-web/Cargo.toml @@ -33,6 +33,8 @@ ulid = { workspace = true } chrono = { workspace = true } ipnet = { workspace = true } dashmap = { version = "6" } +tower_governor = { workspace = true } +governor = { workspace = true } reqwest = { workspace = true } lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] } rand = { workspace = true } diff --git a/crates/pm-web/src/main.rs b/crates/pm-web/src/main.rs old mode 100755 new mode 100644 index 40617f0..e88cad0 --- a/crates/pm-web/src/main.rs +++ b/crates/pm-web/src/main.rs @@ -15,12 +15,11 @@ use pm_core::{ use routes::sso::{OidcCache, SsoSession}; use routes::ws::WsTicket; use serde_json::{json, Value}; -use std::{ - net::{IpAddr, SocketAddr}, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::sync::Mutex; +use tower_governor::{ + governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer, +}; use tower_http::{ services::{ServeDir, ServeFile}, trace::TraceLayer, @@ -41,8 +40,6 @@ pub struct AppState { pub oidc_cache: Arc>, /// Internal certificate authority for mTLS client cert issuance. pub ca: Arc, - /// IP-based rate limits for enrollment requests. - pub enrollment_rate_limits: Arc>, /// Short-lived cache for approved enrollment PKI bundles. pub approved_enrollments: Arc>, } @@ -104,7 +101,6 @@ async fn main() -> anyhow::Result<()> { let ws_tickets: Arc> = Arc::new(DashMap::new()); let sso_sessions: Arc> = Arc::new(DashMap::new()); let oidc_cache: Arc> = Arc::new(Mutex::new(OidcCache::default())); - let enrollment_rate_limits: Arc> = Arc::new(DashMap::new()); let approved_enrollments: Arc> = Arc::new(DashMap::new()); // Background task: purge expired WS tickets every 30 seconds. @@ -144,19 +140,6 @@ async fn main() -> anyhow::Result<()> { }); } - // Background task: purge expired enrollment rate limits every 5 minutes. - { - let limits = enrollment_rate_limits.clone(); - tokio::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(300)); - loop { - interval.tick().await; - let now = Instant::now(); - limits.retain(|_, v| now.duration_since(*v) < Duration::from_secs(3600)); - } - }); - } - // Background task: purge approved enrollment PKI bundles every 10 minutes. { let approved = approved_enrollments.clone(); @@ -177,7 +160,6 @@ async fn main() -> anyhow::Result<()> { ws_tickets, sso_sessions, ca: Arc::new(ca), - enrollment_rate_limits, approved_enrollments, oidc_cache, }; @@ -205,7 +187,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!(%addr, "Listening (HTTPS)"); axum_server::bind_rustls(addr, tls_config) - .serve(app.into_make_service()) + .serve(app.into_make_service_with_connect_info::()) .await?; } else { tracing::warn!( @@ -216,7 +198,11 @@ async fn main() -> anyhow::Result<()> { ); tracing::info!(%addr, "Listening (HTTP — no TLS)"); let listener = tokio::net::TcpListener::bind(addr).await?; - axum::serve(listener, app).await?; + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await?; } Ok(()) @@ -226,8 +212,59 @@ async fn main() -> anyhow::Result<()> { pub fn build_router(state: AppState) -> Router { let static_dir = state.config.server.static_dir.clone(); let auth_config = state.auth_config.clone(); + let rl = &state.config.rate_limit; - // All protected API routes — require valid JWT + // Enrollment rate limiting: strict (5 req/min per IP, burst 3) + // Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy. + // governor quota: 1 request per 12_000ms = ~5/min sustained + let enrollment_governor = Arc::new( + GovernorConfigBuilder::default() + .key_extractor(SmartIpKeyExtractor) + .per_millisecond(12_000) + .burst_size(rl.enrollment_burst) + .finish() + .expect("Invalid enrollment governor config"), + ); + + // Auth rate limiting: moderate (20 req/min per IP, burst 10) + // Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy. + // governor quota: 1 request per 3_000ms = ~20/min sustained + let auth_governor = Arc::new( + GovernorConfigBuilder::default() + .key_extractor(SmartIpKeyExtractor) + .per_millisecond(3_000) + .burst_size(rl.auth_burst) + .finish() + .expect("Invalid auth governor config"), + ); + + // API rate limiting: normal (120 req/min per IP, burst 30) + // Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy. + // governor quota: 1 request per 500ms = ~120/min sustained + let api_governor = Arc::new( + GovernorConfigBuilder::default() + .key_extractor(SmartIpKeyExtractor) + .per_millisecond(500) + .burst_size(rl.api_burst) + .finish() + .expect("Invalid API governor config"), + ); + + // Enrollment routes with strict per-IP rate limiting + let enrollment_router = + routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor)); + + // Public auth routes with moderate per-IP rate limiting + let auth_public_router = + routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor))); + + // SSO routes with moderate per-IP rate limiting + let sso_public_router = + routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor))); + let sso_azure_router = + routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor)); + + // All protected API routes — require valid JWT, with normal per-IP rate limiting let protected_api = Router::new() // Auth: MFA setup/verify // Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*) @@ -267,7 +304,8 @@ pub fn build_router(state: AppState) -> Router { .nest("/settings", routes::settings::router()) // Admin enrollment routes (JWT protected, Admin role enforced) .nest("/admin", routes::enrollment::admin_router()) - // Apply auth middleware to all the above + // Apply rate limiting then auth middleware + .layer(GovernorLayer::new(api_governor)) .route_layer(middleware::from_fn(move |req, next| { let auth_config = auth_config.clone(); require_auth(auth_config, req, next) @@ -275,15 +313,15 @@ pub fn build_router(state: AppState) -> Router { Router::new() .route("/status/health", get(health_handler)) - // Public auth routes (no JWT needed) - .nest("/api/v1/auth", routes::auth::public_router()) + // Public auth routes (rate-limited, no JWT) + .nest("/api/v1/auth", auth_public_router) // Public enrollment endpoints (rate-limited, no JWT) - .nest("/api/v1", routes::enrollment::router()) - // Public SSO routes (no JWT needed) - .nest("/api/v1/auth/sso", routes::sso::public_router()) - // Public Azure SSO routes (no JWT needed) - .nest("/api/v1/auth/azure", routes::sso::azure_compat_router()) - // Protected API routes (JWT required) + .nest("/api/v1", enrollment_router) + // Public SSO routes (rate-limited, no JWT) + .nest("/api/v1/auth/sso", sso_public_router) + // Public Azure SSO routes (rate-limited, no JWT) + .nest("/api/v1/auth/azure", sso_azure_router) + // Protected API routes (JWT required, rate-limited) .nest("/api/v1", protected_api) // WebSocket browser endpoint — ticket-authenticated, outside JWT middleware .merge(routes::ws::ws_router()) diff --git a/crates/pm-web/src/routes/enrollment.rs b/crates/pm-web/src/routes/enrollment.rs index dd1890f..710fcc7 100644 --- a/crates/pm-web/src/routes/enrollment.rs +++ b/crates/pm-web/src/routes/enrollment.rs @@ -1,7 +1,7 @@ use crate::AppState; use axum::{ extract::{Path, State}, - http::{HeaderMap, StatusCode}, + http::StatusCode, response::{IntoResponse, Response}, routing::{delete, get, post}, Json, Router, @@ -16,8 +16,6 @@ use pm_core::{ }; use rand::{distributions::Alphanumeric, Rng}; use serde::Serialize; -use std::net::IpAddr; -use std::time::Instant; #[derive(Debug, Clone, Serialize)] pub struct HostConflict { @@ -34,43 +32,12 @@ pub fn router() -> Router { /// POST /api/v1/enroll /// Initiates host self-enrollment. +/// Rate limiting is handled by tower-governor middleware (per-IP, configurable). async fn enroll_host( State(state): State, - headers: HeaderMap, Json(payload): Json, ) -> Result)> { - // 1. IP-based Rate Limiting - // Prefer real IP from headers if behind proxy (e.g., X-Forwarded-For) - let ip = headers - .get("x-forwarded-for") - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.split(',').next()) - .and_then(|h| h.trim().parse::().ok()) - .unwrap_or_else(|| { - tracing::warn!( - "No X-Forwarded-For header found for enrollment request from public endpoint" - ); - // Default to a placeholder IP since we can't extract the socket addr without the ConnectInfo layer - "0.0.0.0".parse().unwrap() - }); - - { - let mut rate_limits = state - .enrollment_rate_limits - .entry(ip) - .or_insert(Instant::now() - std::time::Duration::from_secs(3600)); - let last_request = rate_limits.value(); - if last_request.elapsed().as_secs() < 60 { - // 1 request per minute per IP - return Err(( - StatusCode::TOO_MANY_REQUESTS, - Json(serde_json::json!({ "error": "Rate limit exceeded. Try again in a minute." })), - )); - } - *rate_limits = Instant::now(); - } - - // 2. Generate secure random polling token + // Generate secure random polling token let polling_token: String = rand::thread_rng() .sample_iter(&Alphanumeric) .take(64)