fix: replace broken DashMap rate limiting with tower-governor middleware
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 1m1s
CI Pipeline / Rust Unit Tests (push) Successful in 1m21s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 16s
CI Pipeline / Build .deb & Release (push) Has been skipped
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 4s
CI Pipeline / Clippy Lints (push) Successful in 1m1s
CI Pipeline / Rust Unit Tests (push) Successful in 1m21s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 16s
CI Pipeline / Build .deb & Release (push) Has been skipped
- Replace custom DashMap<IpAddr, Instant> rate limiting in enrollment.rs that fell back to 0.0.0.0 when X-Forwarded-For was missing, causing ALL enrollment traffic to share a single global rate limit bucket - Use tower_governor with SmartIpKeyExtractor for proper per-IP rate limiting that respects X-Forwarded-For headers (critical behind HAProxy) - Add three configurable rate limit tiers via config.toml: * Enrollment: 5 req/min per IP, burst 3 (strict) * Auth: 20 req/min per IP, burst 10 (moderate) * API: 120 req/min per IP, burst 30 (normal) - Remove enrollment_rate_limits from AppState and cleanup task - Remove manual rate limit code from enrollment.rs (headers param, IP extraction) - Add into_make_service_with_connect_info for ConnectInfo fallback - Add RateLimitConfig to AppConfig with sensible defaults Fixes: #1
This commit is contained in:
@ -33,6 +33,8 @@ ulid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
dashmap = { version = "6" }
|
||||
tower_governor = { workspace = true }
|
||||
governor = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
|
||||
rand = { workspace = true }
|
||||
|
||||
106
crates/pm-web/src/main.rs
Executable file → Normal file
106
crates/pm-web/src/main.rs
Executable file → Normal file
@ -15,12 +15,11 @@ use pm_core::{
|
||||
use routes::sso::{OidcCache, SsoSession};
|
||||
use routes::ws::WsTicket;
|
||||
use serde_json::{json, Value};
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_governor::{
|
||||
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
|
||||
};
|
||||
use tower_http::{
|
||||
services::{ServeDir, ServeFile},
|
||||
trace::TraceLayer,
|
||||
@ -41,8 +40,6 @@ pub struct AppState {
|
||||
pub oidc_cache: Arc<Mutex<OidcCache>>,
|
||||
/// Internal certificate authority for mTLS client cert issuance.
|
||||
pub ca: Arc<pm_ca::CertAuthority>,
|
||||
/// IP-based rate limits for enrollment requests.
|
||||
pub enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>>,
|
||||
/// Short-lived cache for approved enrollment PKI bundles.
|
||||
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
|
||||
}
|
||||
@ -104,7 +101,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
|
||||
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
|
||||
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
|
||||
let enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>> = Arc::new(DashMap::new());
|
||||
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
|
||||
|
||||
// Background task: purge expired WS tickets every 30 seconds.
|
||||
@ -144,19 +140,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
});
|
||||
}
|
||||
|
||||
// Background task: purge expired enrollment rate limits every 5 minutes.
|
||||
{
|
||||
let limits = enrollment_rate_limits.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = Instant::now();
|
||||
limits.retain(|_, v| now.duration_since(*v) < Duration::from_secs(3600));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Background task: purge approved enrollment PKI bundles every 10 minutes.
|
||||
{
|
||||
let approved = approved_enrollments.clone();
|
||||
@ -177,7 +160,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
ws_tickets,
|
||||
sso_sessions,
|
||||
ca: Arc::new(ca),
|
||||
enrollment_rate_limits,
|
||||
approved_enrollments,
|
||||
oidc_cache,
|
||||
};
|
||||
@ -205,7 +187,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
tracing::info!(%addr, "Listening (HTTPS)");
|
||||
axum_server::bind_rustls(addr, tls_config)
|
||||
.serve(app.into_make_service())
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await?;
|
||||
} else {
|
||||
tracing::warn!(
|
||||
@ -216,7 +198,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
tracing::info!(%addr, "Listening (HTTP — no TLS)");
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -226,8 +212,59 @@ async fn main() -> anyhow::Result<()> {
|
||||
pub fn build_router(state: AppState) -> Router {
|
||||
let static_dir = state.config.server.static_dir.clone();
|
||||
let auth_config = state.auth_config.clone();
|
||||
let rl = &state.config.rate_limit;
|
||||
|
||||
// All protected API routes — require valid JWT
|
||||
// Enrollment rate limiting: strict (5 req/min per IP, burst 3)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 12_000ms = ~5/min sustained
|
||||
let enrollment_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(12_000)
|
||||
.burst_size(rl.enrollment_burst)
|
||||
.finish()
|
||||
.expect("Invalid enrollment governor config"),
|
||||
);
|
||||
|
||||
// Auth rate limiting: moderate (20 req/min per IP, burst 10)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 3_000ms = ~20/min sustained
|
||||
let auth_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(3_000)
|
||||
.burst_size(rl.auth_burst)
|
||||
.finish()
|
||||
.expect("Invalid auth governor config"),
|
||||
);
|
||||
|
||||
// API rate limiting: normal (120 req/min per IP, burst 30)
|
||||
// Uses SmartIpKeyExtractor to respect X-Forwarded-For behind reverse proxy.
|
||||
// governor quota: 1 request per 500ms = ~120/min sustained
|
||||
let api_governor = Arc::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.per_millisecond(500)
|
||||
.burst_size(rl.api_burst)
|
||||
.finish()
|
||||
.expect("Invalid API governor config"),
|
||||
);
|
||||
|
||||
// Enrollment routes with strict per-IP rate limiting
|
||||
let enrollment_router =
|
||||
routes::enrollment::router().layer(GovernorLayer::new(enrollment_governor));
|
||||
|
||||
// Public auth routes with moderate per-IP rate limiting
|
||||
let auth_public_router =
|
||||
routes::auth::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
|
||||
|
||||
// SSO routes with moderate per-IP rate limiting
|
||||
let sso_public_router =
|
||||
routes::sso::public_router().layer(GovernorLayer::new(Arc::clone(&auth_governor)));
|
||||
let sso_azure_router =
|
||||
routes::sso::azure_compat_router().layer(GovernorLayer::new(auth_governor));
|
||||
|
||||
// All protected API routes — require valid JWT, with normal per-IP rate limiting
|
||||
let protected_api = Router::new()
|
||||
// Auth: MFA setup/verify
|
||||
// Auth: MFA setup/verify/disable (nested under /auth so paths are /api/v1/auth/mfa/*)
|
||||
@ -267,7 +304,8 @@ pub fn build_router(state: AppState) -> Router {
|
||||
.nest("/settings", routes::settings::router())
|
||||
// Admin enrollment routes (JWT protected, Admin role enforced)
|
||||
.nest("/admin", routes::enrollment::admin_router())
|
||||
// Apply auth middleware to all the above
|
||||
// Apply rate limiting then auth middleware
|
||||
.layer(GovernorLayer::new(api_governor))
|
||||
.route_layer(middleware::from_fn(move |req, next| {
|
||||
let auth_config = auth_config.clone();
|
||||
require_auth(auth_config, req, next)
|
||||
@ -275,15 +313,15 @@ pub fn build_router(state: AppState) -> Router {
|
||||
|
||||
Router::new()
|
||||
.route("/status/health", get(health_handler))
|
||||
// Public auth routes (no JWT needed)
|
||||
.nest("/api/v1/auth", routes::auth::public_router())
|
||||
// Public auth routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth", auth_public_router)
|
||||
// Public enrollment endpoints (rate-limited, no JWT)
|
||||
.nest("/api/v1", routes::enrollment::router())
|
||||
// Public SSO routes (no JWT needed)
|
||||
.nest("/api/v1/auth/sso", routes::sso::public_router())
|
||||
// Public Azure SSO routes (no JWT needed)
|
||||
.nest("/api/v1/auth/azure", routes::sso::azure_compat_router())
|
||||
// Protected API routes (JWT required)
|
||||
.nest("/api/v1", enrollment_router)
|
||||
// Public SSO routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth/sso", sso_public_router)
|
||||
// Public Azure SSO routes (rate-limited, no JWT)
|
||||
.nest("/api/v1/auth/azure", sso_azure_router)
|
||||
// Protected API routes (JWT required, rate-limited)
|
||||
.nest("/api/v1", protected_api)
|
||||
// WebSocket browser endpoint — ticket-authenticated, outside JWT middleware
|
||||
.merge(routes::ws::ws_router())
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
@ -16,8 +16,6 @@ use pm_core::{
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::Serialize;
|
||||
use std::net::IpAddr;
|
||||
use std::time::Instant;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct HostConflict {
|
||||
@ -34,43 +32,12 @@ pub fn router() -> Router<AppState> {
|
||||
|
||||
/// POST /api/v1/enroll
|
||||
/// Initiates host self-enrollment.
|
||||
/// Rate limiting is handled by tower-governor middleware (per-IP, configurable).
|
||||
async fn enroll_host(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<CreateEnrollmentRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
|
||||
// 1. IP-based Rate Limiting
|
||||
// Prefer real IP from headers if behind proxy (e.g., X-Forwarded-For)
|
||||
let ip = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(',').next())
|
||||
.and_then(|h| h.trim().parse::<IpAddr>().ok())
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
"No X-Forwarded-For header found for enrollment request from public endpoint"
|
||||
);
|
||||
// Default to a placeholder IP since we can't extract the socket addr without the ConnectInfo layer
|
||||
"0.0.0.0".parse().unwrap()
|
||||
});
|
||||
|
||||
{
|
||||
let mut rate_limits = state
|
||||
.enrollment_rate_limits
|
||||
.entry(ip)
|
||||
.or_insert(Instant::now() - std::time::Duration::from_secs(3600));
|
||||
let last_request = rate_limits.value();
|
||||
if last_request.elapsed().as_secs() < 60 {
|
||||
// 1 request per minute per IP
|
||||
return Err((
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(serde_json::json!({ "error": "Rate limit exceeded. Try again in a minute." })),
|
||||
));
|
||||
}
|
||||
*rate_limits = Instant::now();
|
||||
}
|
||||
|
||||
// 2. Generate secure random polling token
|
||||
// Generate secure random polling token
|
||||
let polling_token: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(64)
|
||||
|
||||
Reference in New Issue
Block a user