Private
Public Access
1
0

fix: replace broken DashMap rate limiting with tower-governor middleware
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 1m1s
CI Pipeline / Rust Unit Tests (push) Successful in 1m21s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 16s
CI Pipeline / Build .deb & Release (push) Has been skipped

- Replace custom DashMap<IpAddr, Instant> rate limiting in enrollment.rs
  that fell back to 0.0.0.0 when X-Forwarded-For was missing, causing
  ALL enrollment traffic to share a single global rate limit bucket
- Use tower_governor with SmartIpKeyExtractor for proper per-IP rate
  limiting that respects X-Forwarded-For headers (critical behind HAProxy)
- Add three configurable rate limit tiers via config.toml:
  * Enrollment: 5 req/min per IP, burst 3 (strict)
  * Auth: 20 req/min per IP, burst 10 (moderate)
  * API: 120 req/min per IP, burst 30 (normal)
- Remove enrollment_rate_limits from AppState and cleanup task
- Remove manual rate limit code from enrollment.rs (headers param, IP extraction)
- Add into_make_service_with_connect_info for ConnectInfo fallback
- Add RateLimitConfig to AppConfig with sensible defaults

Fixes: #1
This commit is contained in:
2026-05-21 02:27:10 +00:00
parent 6c72dc3ac6
commit 59794bc8f2
7 changed files with 395 additions and 79 deletions

248
Cargo.lock generated
View File

@ -687,6 +687,19 @@ dependencies = [
"cipher", "cipher",
] ]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "dashmap" name = "dashmap"
version = "6.1.0" version = "6.1.0"
@ -771,7 +784,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users", "redox_users",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@ -885,7 +898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@ -970,6 +983,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foldhash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]] [[package]]
name = "font-kit" name = "font-kit"
version = "0.14.3" version = "0.14.3"
@ -1046,6 +1065,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "forwarded-header-value"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
dependencies = [
"nonempty",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "freetype-sys" name = "freetype-sys"
version = "0.20.1" version = "0.20.1"
@ -1155,6 +1184,12 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af43fadb8a98512d547e37b4e92e0ced13e205c061b87b4623eff01d918d6968"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.32" version = "0.3.32"
@ -1242,6 +1277,49 @@ dependencies = [
"weezl", "weezl",
] ]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if",
"dashmap 5.5.3",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.8.6",
"smallvec",
"spinning_top",
]
[[package]]
name = "governor"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9efcab3c1958580ff1f25a2a41be1668f7603d849bb63af523b208a3cc1223b8"
dependencies = [
"cfg-if",
"dashmap 6.1.0",
"futures-sink",
"futures-timer",
"futures-util",
"getrandom 0.3.4",
"hashbrown 0.16.1",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.9.4",
"smallvec",
"spinning_top",
"web-time",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.13" version = "0.4.13"
@ -1275,7 +1353,18 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [ dependencies = [
"allocator-api2", "allocator-api2",
"equivalent", "equivalent",
"foldhash", "foldhash 0.1.5",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.2.0",
] ]
[[package]] [[package]]
@ -1445,6 +1534,19 @@ dependencies = [
"webpki-roots 1.0.7", "webpki-roots 1.0.7",
] ]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
dependencies = [
"hyper",
"hyper-util",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]] [[package]]
name = "hyper-tls" name = "hyper-tls"
version = "0.6.0" version = "0.6.0"
@ -1994,6 +2096,12 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "7.1.3"
@ -2013,13 +2121,25 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "nonempty"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.50.3" version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@ -2307,6 +2427,26 @@ dependencies = [
"sha2", "sha2",
] ]
[[package]]
name = "pin-project"
version = "1.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2466b2336ed02bcdca6b294417127b90ec92038d1d5c4fbeac971a922e0e0924"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c96395f0a926bc13b1c17622aaddda1ecb55d49c8f1bf9777e4d877800a43f8b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.17" version = "0.2.17"
@ -2500,7 +2640,8 @@ dependencies = [
"axum-server", "axum-server",
"base64", "base64",
"chrono", "chrono",
"dashmap", "dashmap 6.1.0",
"governor 0.6.3",
"hex", "hex",
"ipnet", "ipnet",
"jsonwebtoken", "jsonwebtoken",
@ -2520,6 +2661,7 @@ dependencies = [
"tokio", "tokio",
"tower", "tower",
"tower-http", "tower-http",
"tower_governor",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"ulid", "ulid",
@ -2600,6 +2742,12 @@ dependencies = [
"bstr", "bstr",
] ]
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]] [[package]]
name = "potential_utf" name = "potential_utf"
version = "0.1.5" version = "0.1.5"
@ -2662,6 +2810,21 @@ version = "0.1.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f"
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quinn" name = "quinn"
version = "0.11.9" version = "0.11.9"
@ -2803,6 +2966,15 @@ dependencies = [
"getrandom 0.3.4", "getrandom 0.3.4",
] ]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags 2.11.1",
]
[[package]] [[package]]
name = "rcgen" name = "rcgen"
version = "0.13.2" version = "0.13.2"
@ -2999,7 +3171,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@ -3307,7 +3479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@ -3319,6 +3491,15 @@ dependencies = [
"lock_api", "lock_api",
] ]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"
@ -3612,7 +3793,7 @@ dependencies = [
"getrandom 0.4.2", "getrandom 0.4.2",
"once_cell", "once_cell",
"rustix", "rustix",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@ -3912,6 +4093,35 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801"
[[package]]
name = "tonic"
version = "0.14.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac2a5518c70fa84342385732db33fb3f44bc4cc748936eb5833d2df34d6445ef"
dependencies = [
"async-trait",
"axum",
"base64",
"bytes",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"socket2",
"sync_wrapper",
"tokio",
"tokio-stream",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "totp-rs" name = "totp-rs"
version = "5.7.1" version = "5.7.1"
@ -3936,9 +4146,12 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-util", "futures-util",
"indexmap",
"pin-project-lite", "pin-project-lite",
"slab",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-util",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing", "tracing",
@ -3985,6 +4198,23 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tower_governor"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44de9b94d849d3c46e06a883d72d408c2de6403367b39df2b1c9d9e7b6736fe6"
dependencies = [
"axum",
"forwarded-header-value",
"governor 0.10.4",
"http",
"pin-project",
"thiserror 2.0.18",
"tonic",
"tower",
"tracing",
]
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.44" version = "0.1.44"
@ -4477,7 +4707,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]

4
Cargo.toml Executable file → Normal file
View File

@ -81,5 +81,9 @@ aes-gcm = { version = "0.10" }
ipnet = { version = "2" } ipnet = { version = "2" }
url = { version = "2" } url = { version = "2" }
# Rate limiting
tower_governor = { version = "0.8", features = ["tracing"] }
governor = "0.6"
# Email # Email
lettre = { version = "0.11.22", features = ["tokio1-rustls-transport"] } lettre = { version = "0.11.22", features = ["tokio1-rustls-transport"] }

View File

@ -107,3 +107,20 @@ web_tls_key_path = "/etc/patch-manager/tls/web.key"
# The backend sends tokens as query parameters to this URL. # The backend sends tokens as query parameters to this URL.
# Default: "http://localhost:5173/auth/sso/callback" (Vite dev server) # Default: "http://localhost:5173/auth/sso/callback" (Vite dev server)
sso_callback_url = "http://localhost:5173/auth/sso/callback" sso_callback_url = "http://localhost:5173/auth/sso/callback"
# ============================================================
# Rate Limiting
# ============================================================
[rate_limit]
# Enrollment endpoint: requests per minute per IP (default: 5)
enrollment_rpm = 5
# Enrollment burst allowance (default: 3)
enrollment_burst = 3
# Public auth endpoints: requests per minute per IP (default: 20)
auth_rpm = 20
# Auth burst allowance (default: 10)
auth_burst = 10
# Authenticated API: requests per minute per IP (default: 120)
api_rpm = 120
# API burst allowance (default: 30)
api_burst = 30

58
crates/pm-core/src/config.rs Executable file → Normal file
View File

@ -1,6 +1,61 @@
use config::{Config, ConfigError, Environment, File}; use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Rate limiting configuration per route group.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimitConfig {
/// Enrollment endpoint: requests per minute per IP (default: 5)
#[serde(default = "default_enrollment_rpm")]
pub enrollment_rpm: u32,
/// Enrollment burst allowance (default: 3)
#[serde(default = "default_enrollment_burst")]
pub enrollment_burst: u32,
/// Public auth endpoints: requests per minute per IP (default: 20)
#[serde(default = "default_auth_rpm")]
pub auth_rpm: u32,
/// Auth burst allowance (default: 10)
#[serde(default = "default_auth_burst")]
pub auth_burst: u32,
/// Authenticated API: requests per minute per IP (default: 120)
#[serde(default = "default_api_rpm")]
pub api_rpm: u32,
/// API burst allowance (default: 30)
#[serde(default = "default_api_burst")]
pub api_burst: u32,
}
fn default_enrollment_rpm() -> u32 {
5
}
fn default_enrollment_burst() -> u32 {
3
}
fn default_auth_rpm() -> u32 {
20
}
fn default_auth_burst() -> u32 {
10
}
fn default_api_rpm() -> u32 {
120
}
fn default_api_burst() -> u32 {
30
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
enrollment_rpm: default_enrollment_rpm(),
enrollment_burst: default_enrollment_burst(),
auth_rpm: default_auth_rpm(),
auth_burst: default_auth_burst(),
api_rpm: default_api_rpm(),
api_burst: default_api_burst(),
}
}
}
/// Top-level application configuration. /// Top-level application configuration.
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AppConfig { pub struct AppConfig {
@ -9,6 +64,8 @@ pub struct AppConfig {
pub worker: WorkerConfig, pub worker: WorkerConfig,
pub logging: LoggingConfig, pub logging: LoggingConfig,
pub security: SecurityConfig, pub security: SecurityConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
@ -151,6 +208,7 @@ impl Default for AppConfig {
web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(), web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(),
sso_callback_url: default_sso_callback_url(), sso_callback_url: default_sso_callback_url(),
}, },
rate_limit: RateLimitConfig::default(),
} }
} }
} }

View File

@ -33,6 +33,8 @@ ulid = { workspace = true }
chrono = { workspace = true } chrono = { workspace = true }
ipnet = { workspace = true } ipnet = { workspace = true }
dashmap = { version = "6" } dashmap = { version = "6" }
tower_governor = { workspace = true }
governor = { workspace = true }
reqwest = { workspace = true } reqwest = { workspace = true }
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] } lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
rand = { workspace = true } rand = { workspace = true }

106
crates/pm-web/src/main.rs Executable file → Normal file
View File

@ -15,12 +15,11 @@ use pm_core::{
use routes::sso::{OidcCache, SsoSession}; use routes::sso::{OidcCache, SsoSession};
use routes::ws::WsTicket; use routes::ws::WsTicket;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::{ use std::{net::SocketAddr, sync::Arc, time::Duration};
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
};
use tower_http::{ use tower_http::{
services::{ServeDir, ServeFile}, services::{ServeDir, ServeFile},
trace::TraceLayer, trace::TraceLayer,
@ -41,8 +40,6 @@ pub struct AppState {
pub oidc_cache: Arc<Mutex<OidcCache>>, pub oidc_cache: Arc<Mutex<OidcCache>>,
/// Internal certificate authority for mTLS client cert issuance. /// Internal certificate authority for mTLS client cert issuance.
pub ca: Arc<pm_ca::CertAuthority>, pub ca: Arc<pm_ca::CertAuthority>,
/// IP-based rate limits for enrollment requests.
pub enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>>,
/// Short-lived cache for approved enrollment PKI bundles. /// Short-lived cache for approved enrollment PKI bundles.
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>, pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
} }
@ -104,7 +101,6 @@ async fn main() -> anyhow::Result<()> {
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new()); let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new()); let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default())); let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
let enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>> = Arc::new(DashMap::new());
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new()); let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
// Background task: purge expired WS tickets every 30 seconds. // Background task: purge expired WS tickets every 30 seconds.
@ -144,19 +140,6 @@ async fn main() -> anyhow::Result<()> {
}); });
} }
// Background task: purge expired enrollment rate limits every 5 minutes.
{
let limits = enrollment_rate_limits.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300));
loop {
interval.tick().await;
let now = Instant::now();
limits.retain(|_, v| now.duration_since(*v) < Duration::from_secs(3600));
}
});
}
// Background task: purge approved enrollment PKI bundles every 10 minutes. // Background task: purge approved enrollment PKI bundles every 10 minutes.
{ {
let approved = approved_enrollments.clone(); let approved = approved_enrollments.clone();
@ -177,7 +160,6 @@ async fn main() -> anyhow::Result<()> {
ws_tickets, ws_tickets,
sso_sessions, sso_sessions,
ca: Arc::new(ca), ca: Arc::new(ca),
enrollment_rate_limits,
approved_enrollments, approved_enrollments,
oidc_cache, oidc_cache,
}; };
@ -205,7 +187,7 @@ async fn main() -> anyhow::Result<()> {
tracing::info!(%addr, "Listening (HTTPS)"); tracing::info!(%addr, "Listening (HTTPS)");
axum_server::bind_rustls(addr, tls_config) axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service()) .serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?; .await?;
} else { } else {
tracing::warn!( tracing::warn!(
@ -216,7 +198,11 @@ async fn main() -> anyhow::Result<()> {
); );
tracing::info!(%addr, "Listening (HTTP — no TLS)"); tracing::info!(%addr, "Listening (HTTP — no TLS)");
let listener = tokio::net::TcpListener::bind(addr).await?; let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?; axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
} }
Ok(()) Ok(())
@ -226,8 +212,59 @@ async fn main() -> anyhow::Result<()> {
pub fn build_router(state: AppState) -> Router { pub fn build_router(state: AppState) -> Router {
let static_dir = state.config.server.static_dir.clone(); let static_dir = state.config.server.static_dir.clone();
let auth_config = state.auth_config.clone(); let auth_config = state.auth_config.clone();
let rl = &state.config.rate_limit;
// All protected API routes — require valid JWT // Enrollment rate limiting: strict (5 req/min per IP, burst 3)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 12_000ms = ~5/min sustained
let enrollment_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(12_000)
.burst_size(rl.enrollment_burst)
.finish()
.expect("Invalid enrollment governor config"),
);
// Auth rate limiting: moderate (20 req/min per IP, burst 10)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 3_000ms = ~20/min sustained
let auth_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(3_000)
.burst_size(rl.auth_burst)
.finish()
.expect("Invalid auth governor config"),
);
// API rate limiting: normal (120 req/min per IP, burst 30)
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// governor quota: 1 request per 500ms = ~120/min sustained
let api_governor = Arc::new(
GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_millisecond(500)
.burst_size(rl.api_burst)
.finish()
.expect("Invalid API governor config"),
);
// Enrollment routes with strict per-IP rate limiting
let enrollment_router =
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
// Public auth routes with moderate per-IP rate limiting
let auth_public_router =
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
// SSO routes with moderate per-IP rate limiting
let sso_public_router =
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
let sso_azure_router =
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
// All protected API routes — require valid JWT, with normal per-IP rate limiting
let protected_api = Router::new() let protected_api = Router::new()
// Auth: MFA setup/verify // Auth: MFA setup/verify
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*) // Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
@ -267,7 +304,8 @@ pub fn build_router(state: AppState) -> Router {
.nest("/settings", routes::settings::router()) .nest("/settings", routes::settings::router())
// Admin enrollment routes (JWT protected, Admin role enforced) // Admin enrollment routes (JWT protected, Admin role enforced)
.nest("/admin", routes::enrollment::admin_router()) .nest("/admin", routes::enrollment::admin_router())
// Apply auth middleware to all the above // Apply rate limiting then auth middleware
.layer(GovernorLayer::new(api_governor))
.route_layer(middleware::from_fn(move |req, next| { .route_layer(middleware::from_fn(move |req, next| {
let auth_config = auth_config.clone(); let auth_config = auth_config.clone();
require_auth(auth_config, req, next) require_auth(auth_config, req, next)
@ -275,15 +313,15 @@ pub fn build_router(state: AppState) -> Router {
Router::new() Router::new()
.route("/status/health", get(health_handler)) .route("/status/health", get(health_handler))
// Public auth routes (no JWT needed) // Public auth routes (rate-limited, no JWT)
.nest("/api/v1/auth", routes::auth::public_router()) .nest("/api/v1/auth", auth_public_router)
// Public enrollment endpoints (rate-limited, no JWT) // Public enrollment endpoints (rate-limited, no JWT)
.nest("/api/v1", routes::enrollment::router()) .nest("/api/v1", enrollment_router)
// Public SSO routes (no JWT needed) // Public SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/sso", routes::sso::public_router()) .nest("/api/v1/auth/sso", sso_public_router)
// Public Azure SSO routes (no JWT needed) // Public Azure SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/azure", routes::sso::azure_compat_router()) .nest("/api/v1/auth/azure", sso_azure_router)
// Protected API routes (JWT required) // Protected API routes (JWT required, rate-limited)
.nest("/api/v1", protected_api) .nest("/api/v1", protected_api)
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware // WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
.merge(routes::ws::ws_router()) .merge(routes::ws::ws_router())

View File

@ -1,7 +1,7 @@
use crate::AppState; use crate::AppState;
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, State},
http::{HeaderMap, StatusCode}, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{delete, get, post}, routing::{delete, get, post},
Json, Router, Json, Router,
@ -16,8 +16,6 @@ use pm_core::{
}; };
use rand::{distributions::Alphanumeric, Rng}; use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize; use serde::Serialize;
use std::net::IpAddr;
use std::time::Instant;
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct HostConflict { pub struct HostConflict {
@ -34,43 +32,12 @@ pub fn router() -> Router<AppState> {
/// POST /api/v1/enroll /// POST /api/v1/enroll
/// Initiates host self-enrollment. /// Initiates host self-enrollment.
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
async fn enroll_host( async fn enroll_host(
State(state): State<AppState>, State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<CreateEnrollmentRequest>, Json(payload): Json<CreateEnrollmentRequest>,
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> { ) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
// 1. IP-based Rate Limiting // Generate secure random polling token
// Prefer real IP from headers if behind proxy (e.g., X-Forwarded-For)
let ip = headers
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(',').next())
.and_then(|h| h.trim().parse::<IpAddr>().ok())
.unwrap_or_else(|| {
tracing::warn!(
"No X-Forwarded-For header found for enrollment request from public endpoint"
);
// Default to a placeholder IP since we can't extract the socket addr without the ConnectInfo layer
"0.0.0.0".parse().unwrap()
});
{
let mut rate_limits = state
.enrollment_rate_limits
.entry(ip)
.or_insert(Instant::now() - std::time::Duration::from_secs(3600));
let last_request = rate_limits.value();
if last_request.elapsed().as_secs() < 60 {
// 1 request per minute per IP
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({ "error": "Rate limit exceeded. Try again in a minute." })),
));
}
*rate_limits = Instant::now();
}
// 2. Generate secure random polling token
let polling_token: String = rand::thread_rng() let polling_token: String = rand::thread_rng()
.sample_iter(&Alphanumeric) .sample_iter(&Alphanumeric)
.take(64) .take(64)