Private
Public Access
1
0

Apply cargo fmt formatting to fix CI/CD fmt job

This commit is contained in:
2026-04-12 14:13:36 +00:00
parent fa6cf0dba7
commit 17254e5217
21 changed files with 563 additions and 421 deletions

View File

@ -9,8 +9,8 @@
pub mod mtls;
pub mod whitelist;
pub use mtls::{MtlsConfig, MtlsMiddleware, MtlsError, ClientCertInfo};
pub use whitelist::{WhitelistManager, WhitelistMiddleware, WhitelistEntry, WhitelistConfig};
pub use mtls::{ClientCertInfo, MtlsConfig, MtlsError, MtlsMiddleware};
pub use whitelist::{WhitelistConfig, WhitelistEntry, WhitelistManager, WhitelistMiddleware};
/// Combined authentication result
#[derive(Debug, Clone)]
@ -44,7 +44,7 @@ mod tests {
cert_info: None,
client_ip: Some("192.168.1.100".parse().unwrap()),
};
assert!(result.is_authenticated());
assert!(result.mtls_valid);
assert!(result.ip_allowed);
@ -58,7 +58,7 @@ mod tests {
cert_info: None,
client_ip: Some("192.168.1.100".parse().unwrap()),
};
assert!(!result.is_authenticated());
}
@ -70,7 +70,7 @@ mod tests {
cert_info: None,
client_ip: Some("192.168.1.100".parse().unwrap()),
};
assert!(!result.is_authenticated());
}
}

View File

@ -3,13 +3,15 @@
//! Provides mutual TLS authentication middleware for Actix-web.
//! Non-mTLS connections are silently dropped (no response).
use actix_web::http::header;
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage,
};
use chrono::{DateTime, Duration, Utc};
use futures_util::future::LocalBoxFuture;
use rustls::{
server::{WebPkiClientVerifier, ServerConfig},
server::{ServerConfig, WebPkiClientVerifier},
RootCertStore,
};
use rustls_pemfile::{certs, private_key};
@ -20,14 +22,12 @@ use std::{
task::{Context, Poll},
};
use tracing::{debug, info, warn};
use chrono::{DateTime, Utc, Duration};
use actix_web::http::header;
/// Check for duplicate critical headers (VULN-006)
/// Returns true if duplicate headers are detected
fn has_duplicate_critical_headers(req: &ServiceRequest) -> bool {
let critical_headers = ["content-type", "authorization", "host"];
for header_name in critical_headers.iter() {
// Count occurrences of this header
let mut count = 0;
@ -67,7 +67,7 @@ impl MtlsMiddleware {
/// Create a new mTLS middleware
pub fn new(config: MtlsConfig) -> Result<Self, MtlsError> {
let cert_store = load_ca_certs(&config.ca_cert_path)?;
Ok(Self {
config: Arc::new(config),
cert_store: Arc::new(cert_store),
@ -95,21 +95,21 @@ impl MtlsMiddleware {
/// Load CA certificates from PEM file
fn load_ca_certs(path: &str) -> Result<RootCertStore, MtlsError> {
let mut cert_store = RootCertStore::empty();
let cert_file = File::open(path)
.map_err(|e| MtlsError::IoError(format!("Failed to open CA cert {}: {}", path, e)))?;
let mut reader = BufReader::new(cert_file);
let certs = certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| MtlsError::ParseError(format!("Failed to parse CA certs: {}", e)))?;
for cert in certs {
cert_store.add(cert).map_err(|e| {
MtlsError::StoreError(format!("Failed to add CA cert to store: {}", e))
})?;
cert_store
.add(cert)
.map_err(|e| MtlsError::StoreError(format!("Failed to add CA cert to store: {}", e)))?;
}
info!("Loaded CA certificates from {}", path);
Ok(cert_store)
}
@ -119,11 +119,11 @@ fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'stati
let cert_file = File::open(path)
.map_err(|e| MtlsError::IoError(format!("Failed to open cert {}: {}", path, e)))?;
let mut reader = BufReader::new(cert_file);
let certs = certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| MtlsError::ParseError(format!("Failed to parse server certs: {}", e)))?;
Ok(certs)
}
@ -132,11 +132,11 @@ fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'stat
let key_file = File::open(path)
.map_err(|e| MtlsError::IoError(format!("Failed to open key {}: {}", path, e)))?;
let mut reader = BufReader::new(key_file);
let key = private_key(&mut reader)
.map_err(|e| MtlsError::ParseError(format!("Failed to parse private key: {}", e)))?
.ok_or_else(|| MtlsError::ParseError("No private key found in file".to_string()))?;
Ok(key)
}
@ -199,7 +199,7 @@ where
fn call(&self, req: ServiceRequest) -> Self::Future {
let cert_store = self.cert_store.clone();
let peer_addr = req.peer_addr();
// VULN-006: Check for duplicate critical headers before processing
if has_duplicate_critical_headers(&req) {
warn!(
@ -207,15 +207,17 @@ where
"Duplicate critical headers detected - rejecting request (VULN-006)"
);
return Box::pin(async move {
Err(actix_web::error::ErrorBadRequest("Duplicate critical headers not allowed"))
Err(actix_web::error::ErrorBadRequest(
"Duplicate critical headers not allowed",
))
});
}
// Check for client certificate in request extensions
// In a proper mTLS setup with Actix-web + rustls, the certificate
// would be extracted from the TLS connection before reaching this middleware
let has_client_cert = req.extensions().get::<ClientCertInfo>().is_some();
if !has_client_cert {
// No client certificate provided - silent drop
warn!(
@ -224,13 +226,15 @@ where
);
// Return error immediately without calling service
return Box::pin(async move {
Err(actix_web::error::ErrorBadRequest("Client certificate required"))
Err(actix_web::error::ErrorBadRequest(
"Client certificate required",
))
});
}
// Certificate present - validate it
let cert_info = req.extensions().get::<ClientCertInfo>().cloned();
if let Some(info) = cert_info {
// Validate certificate against CA store
match validate_client_certificate(&info, &cert_store) {
@ -249,7 +253,9 @@ where
"mTLS client certificate validation failed - dropping connection"
);
return Box::pin(async move {
Err(actix_web::error::ErrorBadRequest("Certificate validation failed"))
Err(actix_web::error::ErrorBadRequest(
"Certificate validation failed",
))
});
}
}
@ -259,17 +265,17 @@ where
"No client certificate provided - dropping connection (mTLS required)"
);
return Box::pin(async move {
Err(actix_web::error::ErrorBadRequest("Client certificate required"))
Err(actix_web::error::ErrorBadRequest(
"Client certificate required",
))
});
}
debug!("mTLS authentication passed for request");
// All checks passed - call the service
let fut = self.service.call(req);
Box::pin(async move {
fut.await
})
Box::pin(async move { fut.await })
}
}
@ -290,22 +296,22 @@ fn validate_client_certificate(
) -> Result<(), MtlsError> {
// Check certificate validity period
let now = Utc::now();
if now < cert_info.not_before {
return Err(MtlsError::ValidationError(
"Certificate is not yet valid".to_string()
"Certificate is not yet valid".to_string(),
));
}
if now > cert_info.not_after {
return Err(MtlsError::ValidationError(
"Certificate has expired".to_string()
"Certificate has expired".to_string(),
));
}
// In production, would verify certificate chain against CA store
// For now, we trust certificates that were extracted from the TLS connection
Ok(())
}
@ -321,7 +327,7 @@ mod tests {
server_key_path: "/etc/linux_patch_api/certs/server.key".to_string(),
min_tls_version: "1.3".to_string(),
};
assert_eq!(config.ca_cert_path, "/etc/linux_patch_api/certs/ca.pem");
assert_eq!(config.min_tls_version, "1.3");
}
@ -335,15 +341,15 @@ mod tests {
not_before: Utc::now() - Duration::days(1),
not_after: Utc::now() + Duration::days(365),
};
assert!(info.subject.contains("CN="));
assert!(info.issuer.contains("CN="));
// Test validation with valid cert
let cert_store = RootCertStore::empty();
assert!(validate_client_certificate(&info, &cert_store).is_ok());
}
#[test]
fn test_client_cert_expired() {
let info = ClientCertInfo {
@ -353,7 +359,7 @@ mod tests {
not_before: Utc::now() - Duration::days(365),
not_after: Utc::now() - Duration::days(1),
};
let cert_store = RootCertStore::empty();
let result = validate_client_certificate(&info, &cert_store);
assert!(result.is_err());

View File

@ -42,19 +42,19 @@ impl WhitelistManager {
/// Create a new whitelist manager
pub fn new(config_path: &str) -> Result<Self> {
let entries = Arc::new(RwLock::new(HashSet::new()));
let mut manager = Self {
entries: entries.clone(),
config_path: config_path.to_string(),
watcher: None,
};
// Load initial whitelist
manager.reload()?;
// Set up file watcher for auto-reload
manager.setup_watcher()?;
Ok(manager)
}
@ -62,26 +62,27 @@ impl WhitelistManager {
pub fn reload(&self) -> Result<()> {
let config = self.load_config()?;
let entries = self.parse_entries(&config.entries)?;
let mut current_entries = self.entries.write().map_err(|e| {
anyhow::anyhow!("Failed to acquire whitelist lock: {}", e)
})?;
let mut current_entries = self
.entries
.write()
.map_err(|e| anyhow::anyhow!("Failed to acquire whitelist lock: {}", e))?;
*current_entries = entries;
info!(
path = %self.config_path,
count = current_entries.len(),
"Whitelist reloaded successfully"
);
Ok(())
}
/// Check if an IP address is allowed
pub fn is_allowed(&self, ip: &Ipv4Addr) -> bool {
let entries = self.entries.read().unwrap();
for entry in entries.iter() {
match entry {
WhitelistEntry::Ip(allowed_ip) => {
@ -101,7 +102,7 @@ impl WhitelistManager {
}
}
}
false
}
@ -126,38 +127,38 @@ impl WhitelistManager {
fn load_config(&self) -> Result<WhitelistConfig> {
let content = std::fs::read_to_string(&self.config_path)
.with_context(|| format!("Failed to read whitelist config: {}", self.config_path))?;
let config: WhitelistConfig = serde_yaml::from_str(&content)
.with_context(|| format!("Failed to parse whitelist config: {}", self.config_path))?;
Ok(config)
}
/// Parse whitelist entries from strings
fn parse_entries(&self, entries: &[String]) -> Result<HashSet<WhitelistEntry>> {
let mut parsed = HashSet::new();
for entry_str in entries {
let entry_str = entry_str.trim();
// Skip comments and empty lines
if entry_str.is_empty() || entry_str.starts_with('#') {
continue;
}
// Check for CIDR notation
if let Some((ip_str, prefix_str)) = entry_str.split_once('/') {
let ip: Ipv4Addr = ip_str.parse().with_context(|| {
format!("Invalid IP in CIDR notation: {}", entry_str)
})?;
let prefix: u8 = prefix_str.parse().with_context(|| {
format!("Invalid prefix in CIDR notation: {}", entry_str)
})?;
let ip: Ipv4Addr = ip_str
.parse()
.with_context(|| format!("Invalid IP in CIDR notation: {}", entry_str))?;
let prefix: u8 = prefix_str
.parse()
.with_context(|| format!("Invalid prefix in CIDR notation: {}", entry_str))?;
if prefix > 32 {
anyhow::bail!("Invalid CIDR prefix (must be 0-32): {}", entry_str);
}
parsed.insert(WhitelistEntry::Cidr {
network: ip,
prefix,
@ -185,7 +186,7 @@ impl WhitelistManager {
}
}
}
Ok(parsed)
}
@ -193,7 +194,7 @@ impl WhitelistManager {
fn setup_watcher(&mut self) -> Result<()> {
let config_path = self.config_path.clone();
let entries = self.entries.clone();
let watcher = RecommendedWatcher::new(
move |res: Result<Event, notify::Error>| {
if let Ok(event) = res {
@ -208,19 +209,19 @@ impl WhitelistManager {
},
Config::default().with_poll_interval(Duration::from_secs(5)),
)?;
let mut watcher = watcher;
let path = Path::new(&config_path);
if path.exists() {
watcher.watch(path, RecursiveMode::NonRecursive)?;
info!("Watching whitelist file for changes: {}", config_path);
} else {
warn!("Whitelist file does not exist yet: {}", config_path);
}
self.watcher = Some(watcher);
Ok(())
}
}
@ -234,24 +235,24 @@ fn ip_in_subnet(ip: &Ipv4Addr, network: Ipv4Addr, prefix: u8) -> bool {
} else {
!0u32 << (32 - prefix)
};
(ip_bits & mask) == (network_bits & mask)
}
/// Resolve a hostname to an IPv4 address
fn resolve_hostname(hostname: &str) -> Result<Ipv4Addr> {
use std::net::ToSocketAddrs;
let addrs = (hostname, 0)
.to_socket_addrs()
.with_context(|| format!("Failed to resolve hostname: {}", hostname))?;
for addr in addrs {
if let IpAddr::V4(ip) = addr.ip() {
return Ok(ip);
}
}
anyhow::bail!("No IPv4 address found for hostname: {}", hostname)
}
@ -337,11 +338,11 @@ mod tests {
std::fs::write(temp_path, "entries:\n - \"192.168.1.0/24\"\n").unwrap();
WhitelistManager::new(temp_path).unwrap()
});
// Test IP entry
let ip: Ipv4Addr = "192.168.1.100".parse().unwrap();
assert!(manager.is_allowed(&ip));
// Test IP outside subnet
let ip_outside: Ipv4Addr = "192.168.2.100".parse().unwrap();
assert!(!manager.is_allowed(&ip_outside));