Private
Public Access
1
0

fix: replace broken DashMap rate limiting with tower-governor middleware
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 1m1s
CI Pipeline / Rust Unit Tests (push) Successful in 1m21s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 16s
CI Pipeline / Build .deb & Release (push) Has been skipped

- Replace custom DashMap<IpAddr, Instant> rate limiting in enrollment.rs
  that fell back to 0.0.0.0 when X-Forwarded-For was missing, causing
  ALL enrollment traffic to share a single global rate limit bucket
- Use tower_governor with SmartIpKeyExtractor for proper per-IP rate
  limiting that respects X-Forwarded-For headers (critical behind HAProxy)
- Add three configurable rate limit tiers via config.toml:
  * Enrollment: 5 req/min per IP, burst 3 (strict)
  * Auth: 20 req/min per IP, burst 10 (moderate)
  * API: 120 req/min per IP, burst 30 (normal)
- Remove enrollment_rate_limits from AppState and cleanup task
- Remove manual rate limit code from enrollment.rs (headers param, IP extraction)
- Add into_make_service_with_connect_info for ConnectInfo fallback
- Add RateLimitConfig to AppConfig with sensible defaults

Fixes: #1
This commit is contained in:
2026-05-21 02:27:10 +00:00
parent 6c72dc3ac6
commit 59794bc8f2
7 changed files with 395 additions and 79 deletions

58
crates/pm-core/src/config.rs Executable file → Normal file
View File

@ -1,6 +1,61 @@
use config::{Config, ConfigError, Environment, File};
use serde::{Deserialize, Serialize};
/// Rate limiting configuration per route group.
///
/// Three groups are covered (enrollment, public auth, and the authenticated
/// API), each with a requests-per-minute value and a burst allowance. Every
/// field carries its own serde default, so any subset of keys may be omitted
/// from the configuration source.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimitConfig {
/// Enrollment endpoint: requests per minute per IP (default: 5).
#[serde(default = "default_enrollment_rpm")]
pub enrollment_rpm: u32,
/// Enrollment burst allowance (default: 3).
#[serde(default = "default_enrollment_burst")]
pub enrollment_burst: u32,
/// Public auth endpoints: requests per minute per IP (default: 20).
#[serde(default = "default_auth_rpm")]
pub auth_rpm: u32,
/// Auth burst allowance (default: 10).
#[serde(default = "default_auth_burst")]
pub auth_burst: u32,
/// Authenticated API: requests per minute per IP (default: 120).
#[serde(default = "default_api_rpm")]
pub api_rpm: u32,
/// API burst allowance (default: 30).
#[serde(default = "default_api_burst")]
pub api_burst: u32,
}
/// Serde default for `RateLimitConfig::enrollment_rpm`.
fn default_enrollment_rpm() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 5;
    REQUESTS_PER_MINUTE
}

/// Serde default for `RateLimitConfig::enrollment_burst`.
fn default_enrollment_burst() -> u32 {
    const BURST: u32 = 3;
    BURST
}

/// Serde default for `RateLimitConfig::auth_rpm`.
fn default_auth_rpm() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 20;
    REQUESTS_PER_MINUTE
}

/// Serde default for `RateLimitConfig::auth_burst`.
fn default_auth_burst() -> u32 {
    const BURST: u32 = 10;
    BURST
}

/// Serde default for `RateLimitConfig::api_rpm`.
fn default_api_rpm() -> u32 {
    const REQUESTS_PER_MINUTE: u32 = 120;
    REQUESTS_PER_MINUTE
}

/// Serde default for `RateLimitConfig::api_burst`.
fn default_api_burst() -> u32 {
    const BURST: u32 = 30;
    BURST
}
impl Default for RateLimitConfig {
    /// Builds the configuration from the same per-field default functions that
    /// serde uses, so a missing `rate_limit` section and a section with missing
    /// keys resolve to identical values.
    fn default() -> Self {
        let enrollment_rpm = default_enrollment_rpm();
        let enrollment_burst = default_enrollment_burst();
        let auth_rpm = default_auth_rpm();
        let auth_burst = default_auth_burst();
        let api_rpm = default_api_rpm();
        let api_burst = default_api_burst();

        Self {
            enrollment_rpm,
            enrollment_burst,
            auth_rpm,
            auth_burst,
            api_rpm,
            api_burst,
        }
    }
}
/// Top-level application configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AppConfig {
@ -9,6 +64,8 @@ pub struct AppConfig {
pub worker: WorkerConfig,
pub logging: LoggingConfig,
pub security: SecurityConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -151,6 +208,7 @@ impl Default for AppConfig {
web_tls_key_path: "/etc/patch-manager/tls/web.key".to_string(),
sso_callback_url: default_sso_callback_url(),
},
rate_limit: RateLimitConfig::default(),
}
}
}

View File

@ -33,6 +33,8 @@ ulid = { workspace = true }
chrono = { workspace = true }
ipnet = { workspace = true }
dashmap = { version = "6" }
tower_governor = { workspace = true }
governor = { workspace = true }
reqwest = { workspace = true }
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
rand = { workspace = true }

106
crates/pm-web/src/main.rs Executable file → Normal file
View File

@ -15,12 +15,11 @@ use pm_core::{
use routes::sso::{OidcCache, SsoSession};
use routes::ws::WsTicket;
use serde_json::{json, Value};
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::sync::Mutex;
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
};
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
@ -41,8 +40,6 @@ pub struct AppState {
pub oidc_cache: Arc<Mutex<OidcCache>>,
/// Internal certificate authority for mTLS client cert issuance.
pub ca: Arc<pm_ca::CertAuthority>,
/// IP-based rate limits for enrollment requests.
pub enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>>,
/// Short-lived cache for approved enrollment PKI bundles.
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
}
@ -104,7 +101,6 @@ async fn main() -> anyhow::Result<()> {
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
let enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>> = Arc::new(DashMap::new());
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
// Background task: purge expired WS tickets every 30 seconds.
@ -144,19 +140,6 @@ async fn main() -> anyhow::Result<()> {
});
}
// Background task: purge expired enrollment rate limits every 5 minutes.
{
let limits = enrollment_rate_limits.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300));
loop {
interval.tick().await;
let now = Instant::now();
limits.retain(|_, v| now.duration_since(*v) < Duration::from_secs(3600));
}
});
}
// Background task: purge approved enrollment PKI bundles every 10 minutes.
{
let approved = approved_enrollments.clone();
@ -177,7 +160,6 @@ async fn main() -> anyhow::Result<()> {
ws_tickets,
sso_sessions,
ca: Arc::new(ca),
enrollment_rate_limits,
approved_enrollments,
oidc_cache,
};
@ -205,7 +187,7 @@ async fn main() -> anyhow::Result<()> {
tracing::info!(%addr, "Listening (HTTPS)");
axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service())
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
} else {
tracing::warn!(
@ -216,7 +198,11 @@ async fn main() -> anyhow::Result<()> {
);
tracing::info!(%addr, "Listening (HTTP — no TLS)");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
}
Ok(())
@ -226,8 +212,59 @@ async fn main() -> anyhow::Result<()> {
pub fn build_router(state: AppState) -> Router {
let static_dir = state.config.server.static_dir.clone();
let auth_config = state.auth_config.clone();
let rl = &state.config.rate_limit;
// All protected API routes — require valid JWT
// Converts a configured requests-per-minute value into the replenish interval
// (milliseconds per restored request) that the governor builder takes:
// 60_000 / rpm. The divisor and the result are both clamped to at least 1 so
// a misconfigured rpm of 0 cannot divide by zero, and an rpm above 60_000
// cannot produce a zero interval (which the builder is expected to reject).
let replenish_interval_ms = |rpm: u32| -> u64 { (60_000 / u64::from(rpm.max(1))).max(1) };

// Enrollment rate limiting: strict (defaults: 5 req/min per IP, burst 3).
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// The sustained rate comes from `rate_limit.enrollment_rpm` rather than a
// hard-coded interval, so the value in config.toml actually takes effect.
let enrollment_governor = Arc::new(
    GovernorConfigBuilder::default()
        .key_extractor(SmartIpKeyExtractor)
        .per_millisecond(replenish_interval_ms(rl.enrollment_rpm))
        .burst_size(rl.enrollment_burst)
        .finish()
        .expect("Invalid enrollment governor config"),
);

// Auth rate limiting: moderate (defaults: 20 req/min per IP, burst 10).
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// The sustained rate comes from `rate_limit.auth_rpm`.
let auth_governor = Arc::new(
    GovernorConfigBuilder::default()
        .key_extractor(SmartIpKeyExtractor)
        .per_millisecond(replenish_interval_ms(rl.auth_rpm))
        .burst_size(rl.auth_burst)
        .finish()
        .expect("Invalid auth governor config"),
);

// API rate limiting: normal (defaults: 120 req/min per IP, burst 30).
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
// The sustained rate comes from `rate_limit.api_rpm`.
let api_governor = Arc::new(
    GovernorConfigBuilder::default()
        .key_extractor(SmartIpKeyExtractor)
        .per_millisecond(replenish_interval_ms(rl.api_rpm))
        .burst_size(rl.api_burst)
        .finish()
        .expect("Invalid API governor config"),
);
// Enrollment routes with strict per-IP rate limiting
let enrollment_router =
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
// Public auth routes with moderate per-IP rate limiting
let auth_public_router =
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
// SSO routes with moderate per-IP rate limiting
let sso_public_router =
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
let sso_azure_router =
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
// All protected API routes — require valid JWT, with normal per-IP rate limiting
let protected_api = Router::new()
// Auth: MFA setup/verify
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
@ -267,7 +304,8 @@ pub fn build_router(state: AppState) -> Router {
.nest("/settings", routes::settings::router())
// Admin enrollment routes (JWT protected, Admin role enforced)
.nest("/admin", routes::enrollment::admin_router())
// Apply auth middleware to all the above
// Apply rate limiting then auth middleware
.layer(GovernorLayer::new(api_governor))
.route_layer(middleware::from_fn(move |req, next| {
let auth_config = auth_config.clone();
require_auth(auth_config, req, next)
@ -275,15 +313,15 @@ pub fn build_router(state: AppState) -> Router {
Router::new()
.route("/status/health", get(health_handler))
// Public auth routes (no JWT needed)
.nest("/api/v1/auth", routes::auth::public_router())
// Public auth routes (rate-limited, no JWT)
.nest("/api/v1/auth", auth_public_router)
// Public enrollment endpoints (rate-limited, no JWT)
.nest("/api/v1", routes::enrollment::router())
// Public SSO routes (no JWT needed)
.nest("/api/v1/auth/sso", routes::sso::public_router())
// Public Azure SSO routes (no JWT needed)
.nest("/api/v1/auth/azure", routes::sso::azure_compat_router())
// Protected API routes (JWT required)
.nest("/api/v1", enrollment_router)
// Public SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/sso", sso_public_router)
// Public Azure SSO routes (rate-limited, no JWT)
.nest("/api/v1/auth/azure", sso_azure_router)
// Protected API routes (JWT required, rate-limited)
.nest("/api/v1", protected_api)
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
.merge(routes::ws::ws_router())

View File

@ -1,7 +1,7 @@
use crate::AppState;
use axum::{
extract::{Path, State},
http::{HeaderMap, StatusCode},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
Json, Router,
@ -16,8 +16,6 @@ use pm_core::{
};
use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize;
use std::net::IpAddr;
use std::time::Instant;
#[derive(Debug, Clone, Serialize)]
pub struct HostConflict {
@ -34,43 +32,12 @@ pub fn router() -> Router<AppState> {
/// POST /api/v1/enroll
/// Initiates host self-enrollment.
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
async fn enroll_host(
State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<CreateEnrollmentRequest>,
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
// 1. IP-based Rate Limiting
// Prefer real IP from headers if behind proxy (e.g., X-Forwarded-For)
let ip = headers
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(',').next())
.and_then(|h| h.trim().parse::<IpAddr>().ok())
.unwrap_or_else(|| {
tracing::warn!(
"No X-Forwarded-For header found for enrollment request from public endpoint"
);
// Default to a placeholder IP since we can't extract the socket addr without the ConnectInfo layer
"0.0.0.0".parse().unwrap()
});
{
let mut rate_limits = state
.enrollment_rate_limits
.entry(ip)
.or_insert(Instant::now() - std::time::Duration::from_secs(3600));
let last_request = rate_limits.value();
if last_request.elapsed().as_secs() < 60 {
// 1 request per minute per IP
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({ "error": "Rate limit exceeded. Try again in a minute." })),
));
}
*rate_limits = Instant::now();
}
// 2. Generate secure random polling token
// Generate secure random polling token
let polling_token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)