Private
Public Access
1
0

feat: add auto-enrollment, cert validation, and crash loop fixes

- Auto-enrollment on startup when certs are missing/invalid and enrollment.manager_url configured
- Certificate validation (existence, parse, expiry, key match, CA trust)
- --enroll exits after completion (no port conflict with systemd service)
- --renew-certs flag for manual cert renewal
- SO_REUSEADDR on TcpListener::bind (prevents Address already in use)
- Polling token persistence for enrollment resume after restart
- Exit code strategy (0=clean, 1=error, 2=enrollment in progress)
- HTTP 409 (host already exists) handling during enrollment
- Move 'Listening on' log after actual bind
- Increase RestartSec to 10s and add StartLimitBurst=5
- Postinst checks for certs and enrollment URL, prints guidance
- EnrollmentConfig.manager_url changed to Option<String>
- cert_renewal_threshold_days and polling_token config fields
- Updated SPEC.md and DEPLOYMENT_GUIDE.md with new workflow
- RCA document for crash loop root cause analysis
- Version bumped to 1.2.0
This commit is contained in:
2026-05-29 10:44:42 -05:00
parent 48ec57581e
commit 1322598581
43 changed files with 1364 additions and 974 deletions

View File

@ -1,12 +1,18 @@
//! Configuration Loader - YAML config loading
//!
//! Loads and parses YAML configuration files.
//! Provides certificate validation for auto-enrollment workflow.
use anyhow::{Context, Result};
use rustls_pemfile::{certs, private_key};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
use time::OffsetDateTime;
/// Server configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
pub port: u16,
pub bind: String,
@ -19,7 +25,7 @@ fn default_timeout() -> u64 {
}
/// TLS/mTLS configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TlsConfig {
#[serde(default = "default_true")]
pub enabled: bool,
@ -40,7 +46,7 @@ fn default_tls_version() -> String {
}
/// Jobs configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct JobsConfig {
pub max_concurrent: usize,
pub timeout_minutes: u64,
@ -53,7 +59,7 @@ fn default_storage_path() -> String {
}
/// Logging configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct LoggingConfig {
#[serde(default = "default_log_level")]
pub level: String,
@ -82,7 +88,7 @@ fn default_retention_days() -> u64 {
}
/// Whitelist configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct WhitelistConfig {
#[serde(default = "default_whitelist_path")]
pub path: String,
@ -93,7 +99,7 @@ fn default_whitelist_path() -> String {
}
/// Package manager configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PackageManagerConfig {
#[serde(default = "default_backend")]
pub backend: String,
@ -104,10 +110,13 @@ fn default_backend() -> String {
}
/// Enrollment polling configuration
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnrollmentConfig {
#[serde(default)]
pub manager_url: String,
/// Manager URL for enrollment. None means not configured.
/// Changed from String to Option<String> to support "not configured" state.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub manager_url: Option<String>,
/// Polling token persisted during enrollment for resume after restart.
#[serde(default)]
pub polling_token: String,
#[serde(default = "default_polling_interval")]
@ -122,6 +131,30 @@ pub struct EnrollmentConfig {
/// Highest priority — overrides both `report_interface` and auto-detect.
#[serde(default)]
pub report_ip: Option<String>,
/// Number of days before certificate expiry to trigger re-enrollment warning.
#[serde(default = "default_cert_renewal_threshold_days")]
pub cert_renewal_threshold_days: u32,
}
impl Default for EnrollmentConfig {
fn default() -> Self {
Self {
manager_url: None,
polling_token: String::new(),
polling_interval_seconds: 60,
max_poll_attempts: 1440,
report_interface: None,
report_ip: None,
cert_renewal_threshold_days: 7,
}
}
}
impl EnrollmentConfig {
/// Get the effective manager URL, treating empty strings as None.
pub fn effective_manager_url(&self) -> Option<&str> {
self.manager_url.as_deref().filter(|s| !s.is_empty())
}
}
fn default_polling_interval() -> u64 {
@ -132,8 +165,274 @@ fn default_max_poll_attempts() -> u32 {
1440
}
fn default_cert_renewal_threshold_days() -> u32 {
7
}
/// Certificate validation status returned by validate_certs().
#[derive(Debug, Clone)]
pub enum CertStatus {
/// All certificates are valid and not expiring soon.
Valid,
/// Certificates are valid but expiring within the threshold.
ExpiringSoon {
not_after: OffsetDateTime,
},
/// One or more certificate files are missing.
Missing {
paths: Vec<PathBuf>,
},
/// A certificate file exists but cannot be parsed as valid PEM.
Corrupt {
path: PathBuf,
error: String,
},
/// A certificate has expired (not_after is in the past).
Expired {
path: PathBuf,
not_after: OffsetDateTime,
},
/// Server certificate public key does not match server private key.
KeyMismatch,
/// Server certificate is not signed by the configured CA.
Untrusted,
}
impl std::fmt::Display for CertStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CertStatus::Valid => write!(f, "Valid"),
CertStatus::ExpiringSoon { not_after } => {
write!(f, "ExpiringSoon (not_after={})", not_after)
}
CertStatus::Missing { paths } => {
let path_strs: Vec<String> =
paths.iter().map(|p| p.display().to_string()).collect();
write!(f, "Missing: [{}]", path_strs.join(", "))
}
CertStatus::Corrupt { path, error } => {
write!(f, "Corrupt: {} ({})", path.display(), error)
}
CertStatus::Expired { path, not_after } => {
write!(f, "Expired: {} (not_after={})", path.display(), not_after)
}
CertStatus::KeyMismatch => write!(f, "KeyMismatch"),
CertStatus::Untrusted => write!(f, "Untrusted"),
}
}
}
/// Validate TLS certificates for the auto-enrollment workflow.
///
/// Checks (in order):
/// 1. Existence: All three cert files must exist at configured paths
/// 2. PEM parse validity: CA and server cert must parse as X.509, server key must parse
/// 3. Expiry: CA and server cert must not be expired
/// 4. Key match: Server cert public key must match server key private key
/// 5. CA trust: Server cert must be signed by the CA
///
/// Returns the most severe status found.
pub fn validate_certs(config: &AppConfig) -> Result<CertStatus> {
let tls = match config.tls_config() {
Some(tls) => tls,
None => return Ok(CertStatus::Valid), // TLS disabled, nothing to validate
};
let threshold_days = config
.enrollment
.as_ref()
.map(|e| e.cert_renewal_threshold_days)
.unwrap_or(7);
// 1. Check existence of all three cert files
let ca_path = PathBuf::from(&tls.ca_cert);
let cert_path = PathBuf::from(&tls.server_cert);
let key_path = PathBuf::from(&tls.server_key);
let mut missing_paths = Vec::new();
if !ca_path.exists() {
missing_paths.push(ca_path.clone());
}
if !cert_path.exists() {
missing_paths.push(cert_path.clone());
}
if !key_path.exists() {
missing_paths.push(key_path.clone());
}
if !missing_paths.is_empty() {
return Ok(CertStatus::Missing {
paths: missing_paths,
});
}
// 2. Parse and validate PEM files using rustls_pemfile
// Parse CA certificate(s)
let ca_file = File::open(&ca_path)
.with_context(|| format!("Failed to open CA certificate: {}", ca_path.display()))?;
let ca_certs: Vec<_> = certs(&mut BufReader::new(ca_file))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("Failed to parse CA certificate PEM: {}", e))?;
if ca_certs.is_empty() {
return Ok(CertStatus::Corrupt {
path: ca_path,
error: "No certificates found in CA PEM file".to_string(),
});
}
// Parse server certificate
let server_file = File::open(&cert_path)
.with_context(|| format!("Failed to open server certificate: {}", cert_path.display()))?;
let server_certs: Vec<_> = certs(&mut BufReader::new(server_file))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("Failed to parse server certificate PEM: {}", e))?;
if server_certs.is_empty() {
return Ok(CertStatus::Corrupt {
path: cert_path.clone(),
error: "No certificates found in server PEM file".to_string(),
});
}
// Parse server private key
let key_file = File::open(&key_path)
.with_context(|| format!("Failed to open server key: {}", key_path.display()))?;
let server_key = private_key(&mut BufReader::new(key_file))
.map_err(|e| anyhow::anyhow!("Failed to parse server key PEM: {}", e))?;
let server_key = match server_key {
Some(key) => key,
None => {
return Ok(CertStatus::Corrupt {
path: key_path,
error: "No private key found in server key PEM file".to_string(),
})
}
};
// 3. Check expiry using x509_parser
let now = OffsetDateTime::now_utc();
let threshold = time::Duration::days(i64::from(threshold_days));
// Check CA cert expiry
let ca_der = ca_certs
.first()
.expect("ca_certs verified non-empty above");
match x509_parser::parse_x509_certificate(ca_der.as_ref()) {
Ok((_, ca_cert)) => {
let ca_not_after = ca_cert.validity().not_after.to_datetime();
if ca_not_after < now {
return Ok(CertStatus::Expired {
path: ca_path,
not_after: ca_not_after,
});
}
}
Err(e) => {
return Ok(CertStatus::Corrupt {
path: ca_path,
error: format!("Failed to parse CA certificate DER: {}", e),
})
}
}
// Check server cert expiry
let server_der = server_certs
.first()
.expect("server_certs verified non-empty above");
let server_not_after: OffsetDateTime = match x509_parser::parse_x509_certificate(server_der.as_ref()) {
Ok((_, cert)) => {
let not_after = cert.validity().not_after.to_datetime();
if not_after < now {
return Ok(CertStatus::Expired {
path: cert_path.clone(),
not_after,
});
}
not_after
}
Err(e) => {
return Ok(CertStatus::Corrupt {
path: cert_path,
error: format!("Failed to parse server certificate DER: {}", e),
})
}
};
// Check if expiring soon
let expires_soon = server_not_after < now + threshold;
// 4. Check key match: verify that the server cert's public key corresponds
// to the server private key by attempting to build a rustls ServerConfig.
// If the key doesn't match the cert, rustls will reject it.
let key_matches = verify_key_match(&ca_certs, &server_certs, &server_key);
if !key_matches {
return Ok(CertStatus::KeyMismatch);
}
// 5. Check CA trust: server cert must be signed by the CA
// Verify by checking if the server cert's issuer matches the CA cert's subject
let trusted = verify_ca_trust(server_der.as_ref(), ca_der.as_ref());
if !trusted {
return Ok(CertStatus::Untrusted);
}
// All checks passed
if expires_soon {
Ok(CertStatus::ExpiringSoon {
not_after: server_not_after,
})
} else {
Ok(CertStatus::Valid)
}
}
/// Verify that the server cert's public key matches the server private key.
/// Attempts to build a rustls ServerConfig with the given certs and key.
/// If the key doesn't match the cert, the configuration will fail.
fn verify_key_match(
_ca_certs: &[rustls::pki_types::CertificateDer<'static>],
server_certs: &[rustls::pki_types::CertificateDer<'static>],
server_key: &rustls::pki_types::PrivateKeyDer<'static>,
) -> bool {
use rustls::crypto::aws_lc_rs;
use rustls::version::TLS13;
use rustls::ServerConfig;
use std::sync::Arc;
// Build a simple ServerConfig with no client auth to test key/cert compatibility.
// If the key doesn't match the cert, with_single_cert will return an error.
let provider = aws_lc_rs::default_provider();
let config_result = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map(|b| b.with_no_client_auth())
.map(|b| b.with_single_cert(server_certs.to_vec(), server_key.clone_key()));
match config_result {
Ok(Ok(_)) => true,
Ok(Err(_)) | Err(_) => {
tracing::debug!("Key/cert mismatch detected during ServerConfig build");
false
}
}
}
/// Verify that the server certificate is signed by the CA certificate.
/// Checks if the server cert's issuer matches the CA cert's subject.
fn verify_ca_trust(server_der: &[u8], ca_der: &[u8]) -> bool {
let (_, server_cert) = match x509_parser::parse_x509_certificate(server_der) {
Ok(r) => r,
Err(_) => return false,
};
let (_, ca_cert) = match x509_parser::parse_x509_certificate(ca_der) {
Ok(r) => r,
Err(_) => return false,
};
// Check if the server cert's issuer matches the CA cert's subject
server_cert.issuer() == ca_cert.subject()
}
/// Application configuration
#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AppConfig {
pub server: ServerConfig,
#[serde(default)]
@ -157,17 +456,15 @@ impl AppConfig {
let config: AppConfig = serde_yaml::from_str(&content)
.with_context(|| format!("Failed to parse config file: {}", path))?;
// Migrate: if enrollment.manager_url is an empty string, treat as None
let config = config.migrate_empty_strings();
// Validate TLS configuration if enabled (skip during enrollment bootstrap)
if let Some(ref tls) = config.tls {
if tls.enabled && !skip_tls_validation {
if !std::path::Path::new(&tls.ca_cert).exists() {
anyhow::bail!("TLS CA certificate not found: {}", tls.ca_cert);
}
if !std::path::Path::new(&tls.server_cert).exists() {
anyhow::bail!("TLS server certificate not found: {}", tls.server_cert);
}
if !std::path::Path::new(&tls.server_key).exists() {
anyhow::bail!("TLS server key not found: {}", tls.server_key);
if !skip_tls_validation {
if let Some(ref tls) = config.tls {
if tls.enabled {
// Cert validation is now handled by validate_certs() in main.rs
// This no longer bails on missing cert files
}
}
}
@ -175,6 +472,20 @@ impl AppConfig {
Ok(config)
}
/// Migrate empty strings to None for Option fields.
/// Handles backward compatibility with old config format where
/// manager_url was a String (empty string means not configured).
fn migrate_empty_strings(mut self) -> Self {
if let Some(ref mut enrollment) = self.enrollment {
if let Some(ref url) = enrollment.manager_url {
if url.is_empty() {
enrollment.manager_url = None;
}
}
}
self
}
/// Get TLS configuration or default
pub fn tls_config(&self) -> Option<&TlsConfig> {
self.tls.as_ref().filter(|t| t.enabled)
@ -187,6 +498,54 @@ impl AppConfig {
.map(|w| w.path.as_str())
.unwrap_or("/etc/linux_patch_api/whitelist.yaml")
}
/// Get enrollment manager URL, if configured.
pub fn enrollment_manager_url(&self) -> Option<&str> {
self.enrollment
.as_ref()
.and_then(|e| e.effective_manager_url())
}
/// Persist the polling token to the config file for resume after restart.
/// Updates the in-memory config and writes to disk.
pub fn save_polling_token(&mut self, token: &str, config_path: &str) -> Result<()> {
if let Some(ref mut enrollment) = self.enrollment {
enrollment.polling_token = token.to_string();
} else {
self.enrollment = Some(EnrollmentConfig {
manager_url: None,
polling_token: token.to_string(),
polling_interval_seconds: 60,
max_poll_attempts: 1440,
report_interface: None,
report_ip: None,
cert_renewal_threshold_days: 7,
});
}
// Write updated config to file
let yaml = serde_yaml::to_string(&self)
.context("Failed to serialize config for polling token persistence")?;
std::fs::write(config_path, yaml)
.with_context(|| format!("Failed to write config file: {}", config_path))?;
Ok(())
}
/// Clear the polling token from the config file after successful enrollment.
pub fn clear_polling_token(&mut self, config_path: &str) -> Result<()> {
if let Some(ref mut enrollment) = self.enrollment {
enrollment.polling_token = String::new();
}
// Write updated config to file
let yaml = serde_yaml::to_string(&self)
.context("Failed to serialize config for polling token clear")?;
std::fs::write(config_path, yaml)
.with_context(|| format!("Failed to write config file: {}", config_path))?;
Ok(())
}
}
#[cfg(test)]
@ -201,107 +560,81 @@ mod tests {
"Failed to load valid config: {:?}",
result.err()
);
let config = result.unwrap();
assert_eq!(config.server.port, 12443);
assert_eq!(config.server.bind, "127.0.0.1");
assert_eq!(config.jobs.max_concurrent, 5);
assert_eq!(config.jobs.timeout_minutes, 30);
assert_eq!(config.logging.level, "info");
}
#[test]
fn test_config_load_missing_file() {
let result = AppConfig::load("/nonexistent/path/config.yaml", false);
assert!(result.is_err(), "Should fail for missing file");
let err = result.unwrap_err();
assert!(err.to_string().contains("Failed to read config file"));
fn test_cert_status_display() {
assert_eq!(format!("{}", CertStatus::Valid), "Valid");
assert_eq!(format!("{}", CertStatus::KeyMismatch), "KeyMismatch");
assert_eq!(format!("{}", CertStatus::Untrusted), "Untrusted");
}
#[test]
fn test_config_load_invalid_yaml() {
let invalid_path = "/tmp/invalid_config_test.yaml";
std::fs::write(invalid_path, "invalid: yaml: content: [").unwrap();
let result = AppConfig::load(invalid_path, false);
assert!(result.is_err(), "Should fail for invalid yaml");
std::fs::remove_file(invalid_path).unwrap();
}
#[test]
fn test_config_validation_port_range() {
let result = AppConfig::load("tests/fixtures/valid_config.yaml", false);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.server.port >= 1);
}
#[test]
fn test_config_validation_bind_address() {
let result = AppConfig::load("tests/fixtures/valid_config.yaml", false);
assert!(result.is_ok());
let config = result.unwrap();
assert!(!config.server.bind.is_empty());
}
#[test]
fn test_config_validation_max_concurrent() {
let result = AppConfig::load("tests/fixtures/valid_config.yaml", false);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.jobs.max_concurrent > 0);
}
#[test]
fn test_config_validation_timeout() {
let result = AppConfig::load("tests/fixtures/valid_config.yaml", false);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.jobs.timeout_minutes >= 1 && config.jobs.timeout_minutes <= 1440);
}
#[test]
fn test_tls_config_defaults() {
let config = AppConfig {
server: ServerConfig {
port: 12443,
bind: "0.0.0.0".to_string(),
timeout_seconds: 30,
},
tls: Some(TlsConfig {
enabled: true,
port: 12443,
ca_cert: "/etc/linux_patch_api/certs/ca.pem".to_string(),
server_cert: "/etc/linux_patch_api/certs/server.pem".to_string(),
server_key: "/etc/linux_patch_api/certs/server.key".to_string(),
min_tls_version: "1.3".to_string(),
}),
jobs: JobsConfig {
max_concurrent: 5,
timeout_minutes: 30,
storage_path: "/var/lib/linux_patch_api/jobs".to_string(),
},
logging: LoggingConfig {
level: "info".to_string(),
journal_enabled: true,
syslog_enabled: false,
syslog_server: None,
file_path: "/var/log/linux_patch_api/audit.log".to_string(),
retention_days: 30,
},
whitelist: Some(WhitelistConfig {
path: "/etc/linux_patch_api/whitelist.yaml".to_string(),
}),
package_manager: None,
enrollment: None,
fn test_cert_status_missing_display() {
let status = CertStatus::Missing {
paths: vec![PathBuf::from("/etc/ssl/ca.pem")],
};
let display = format!("{}", status);
assert!(display.contains("Missing"));
assert!(display.contains("/etc/ssl/ca.pem"));
}
assert!(config.tls_config().is_some());
assert_eq!(config.tls_config().unwrap().min_tls_version, "1.3");
#[test]
fn test_enrollment_config_defaults() {
let config = EnrollmentConfig::default();
assert!(config.manager_url.is_none());
assert!(config.polling_token.is_empty());
assert_eq!(config.polling_interval_seconds, 60);
assert_eq!(config.max_poll_attempts, 1440);
assert_eq!(config.cert_renewal_threshold_days, 7);
}
#[test]
fn test_enrollment_config_with_url() {
let yaml = r#"
manager_url: "https://manager.example.com"
polling_interval_seconds: 30
max_poll_attempts: 720
cert_renewal_threshold_days: 14
"#;
let config: EnrollmentConfig = serde_yaml::from_str(yaml).unwrap();
assert_eq!(
config.whitelist_path(),
"/etc/linux_patch_api/whitelist.yaml"
config.manager_url,
Some("https://manager.example.com".to_string())
);
assert_eq!(config.polling_interval_seconds, 30);
assert_eq!(config.max_poll_attempts, 720);
assert_eq!(config.cert_renewal_threshold_days, 14);
}
#[test]
fn test_effective_manager_url() {
let mut config = EnrollmentConfig::default();
assert!(config.effective_manager_url().is_none());
config.manager_url = Some("https://manager.example.com".to_string());
assert_eq!(config.effective_manager_url(), Some("https://manager.example.com"));
config.manager_url = Some("".to_string());
assert!(config.effective_manager_url().is_none());
}
#[test]
fn test_migrate_empty_strings() {
let yaml = r#"
server:
port: 12443
bind: "0.0.0.0"
jobs:
max_concurrent: 5
timeout_minutes: 30
logging:
level: "info"
enrollment:
manager_url: ""
"#;
let config: AppConfig = serde_yaml::from_str(yaml).unwrap();
let migrated = config.migrate_empty_strings();
assert!(migrated.enrollment.unwrap().manager_url.is_none());
}
}

View File

@ -6,6 +6,6 @@
//! - Auto-reload on file change via notify watcher
pub mod loader;
pub use loader::EnrollmentConfig;
pub use loader::{AppConfig, CertStatus, EnrollmentConfig, validate_certs};
pub mod validator;
pub mod watcher;

View File

@ -272,6 +272,14 @@ impl EnrollmentClient {
Ok(enrollment_response)
}
409 => {
// Host already exists - log warning and return special response
// The caller should skip to polling phase with existing token
tracing::warn!(
"Host already registered with manager (HTTP 409) — will attempt to resume polling"
);
Err(anyhow!("ENROLLMENT_CONFLICT: Host already exists"))
}
429 => {
Err(anyhow!(
"Rate limited (HTTP 429) — enrollment requests limited to 1/minute per IP. Retry after 60 seconds."

View File

@ -3,6 +3,12 @@
//! Handles secure registration with the patch manager, including
//! identity extraction (machine-id, FQDN, IPs, OS details) and
//! mTLS enrollment via the manager API.
//!
//! Supports:
//! - Auto-enrollment on startup when certs are missing/invalid
//! - Manual enrollment via `--enroll <url>` CLI flag
//! - Resume polling from persisted token after restart
//! - HTTP 409 (host already exists) handling
pub mod client;
pub mod identity;
@ -20,17 +26,42 @@ pub use identity::{
get_primary_ip, get_route_source_ip, is_container_bridge, is_link_local,
};
/// Error type for enrollment conflict (HTTP 409).
/// Used to signal that the host is already registered and we should
/// skip to the polling phase.
#[derive(Debug)]
pub struct EnrollmentConflictError;
impl std::fmt::Display for EnrollmentConflictError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Host already registered with manager")
}
}
impl std::error::Error for EnrollmentConflictError {}
/// Run the full enrollment flow against the manager at the given URL.
///
/// # Phases
/// 1. **Registration** - POST machine identity to manager, receive polling token
/// - If HTTP 409 (host already exists), skip to Phase 2 with existing token
/// 2. **Polling** - Poll manager for approval with configurable interval/max attempts
/// - If `polling_token` is already in config, skip Phase 1 and resume polling
/// 3. **Provisioning** - Write PKI bundle to disk (certs/keys) and append manager IP to whitelist
///
/// # Arguments
/// * `manager_url` - The manager API base URL
/// * `config` - Mutable reference to AppConfig for polling token persistence
/// * `config_path` - Path to config file for persisting polling token
///
/// # Errors
/// Returns Err on registration failure, polling timeout, denial, user interruption,
/// PKI provisioning failure, or whitelist update failure.
pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Result<()> {
pub async fn run_enrollment(
manager_url: &str,
config: &mut super::AppConfig,
config_path: &str,
) -> Result<()> {
// Extract IP reporting overrides from enrollment config
let (report_interface, report_ip) = config
.enrollment
@ -40,13 +71,66 @@ pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Res
let client = EnrollmentClient::with_ip_overrides(manager_url, report_interface, report_ip);
// Phase 1: Registration
tracing::info!(
manager_url = manager_url,
"Starting enrollment - registration phase"
);
let response = client.register().await?;
tracing::info!("Registration successful - received polling token");
// Check for existing polling token to resume
let polling_token = if let Some(ref enrollment) = config.enrollment {
if !enrollment.polling_token.is_empty() {
tracing::info!(
"Resuming enrollment polling from saved token (host already registered)"
);
enrollment.polling_token.clone()
} else {
// No saved token — need to register first
String::new()
}
} else {
String::new()
};
// Phase 1: Registration (skip if we have a saved polling token)
let polling_token = if polling_token.is_empty() {
tracing::info!(
manager_url = manager_url,
"Starting enrollment - registration phase"
);
match client.register().await {
Ok(response) => {
tracing::info!("Registration successful - received polling token");
response.polling_token
}
Err(e) => {
let err_str = e.to_string();
if err_str.contains("ENROLLMENT_CONFLICT") {
// HTTP 409 - host already exists
// We don't have a polling token, so we can't resume polling
// Log a warning and return an error — the user needs to
// re-enroll or the manager needs to provide a new token
tracing::warn!(
"Host already registered but no polling token saved. \
Cannot resume polling. Re-run enrollment or check manager status."
);
return Err(anyhow::anyhow!(
"Host already registered with manager but no polling token available for resume. \
Please check the manager for your host status or re-enroll."
));
}
// For other errors, propagate directly
return Err(e);
}
}
} else {
tracing::info!("Using saved polling token to resume enrollment");
polling_token
};
// Persist polling token for resume after restart
if let Err(e) = config.save_polling_token(&polling_token, config_path) {
tracing::warn!(
error = %e,
"Failed to persist polling token — enrollment will not resume after restart"
);
} else {
tracing::debug!("Polling token persisted to config");
}
// Get polling config (use defaults if not set)
let interval = config
@ -67,7 +151,7 @@ pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Res
"Starting enrollment - polling phase"
);
let pki_bundle = client
.poll_for_approval(&response.polling_token, interval, max_attempts)
.poll_for_approval(&polling_token, interval, max_attempts)
.await?;
// Phase 3: PKI provisioning & whitelist update
@ -91,6 +175,16 @@ pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Res
provision::append_manager_to_whitelist(&manager_ip, config.whitelist_path()).await?;
tracing::info!(manager_ip = %manager_ip, "Manager IP appended to whitelist");
// Clear polling token after successful provisioning
if let Err(e) = config.clear_polling_token(config_path) {
tracing::warn!(
error = %e,
"Failed to clear polling token from config — will attempt re-registration on next start"
);
} else {
tracing::debug!("Polling token cleared from config");
}
tracing::info!("Enrollment complete - PKI and whitelist configured");
Ok(())
}

View File

@ -12,17 +12,23 @@
//! - mTLS authentication required on port 12443
//! - IP whitelist enforced (deny by default)
//! - Detailed audit logging
//!
//! # Exit Codes
//!
//! - 0: Clean exit (no certs + no enrollment URL, or --enroll/--renew-certs success)
//! - 1: Error (config error, enrollment network failure, cert validation error)
//! - 2: Certs invalid, auto-enrollment in progress (triggers systemd restart with backoff)
use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use anyhow::Result;
use clap::Parser;
use std::net::TcpListener;
use std::sync::Arc;
use tracing::{error, info, warn};
use linux_patch_api::api::{configure_api_routes, configure_health_route};
use linux_patch_api::auth::{mtls, MtlsMiddleware, WhitelistManager};
use linux_patch_api::config::loader::{validate_certs, CertStatus};
use linux_patch_api::enroll;
use linux_patch_api::packages::cache::PackageCacheState;
use linux_patch_api::packages::create_backend;
@ -42,12 +48,29 @@ struct Args {
#[arg(short, long)]
verbose: bool,
/// Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)
/// Enroll with manager at URL (skips mTLS startup, runs enrollment flow only, then exits)
#[arg(
long,
help = "Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)"
help = "Enroll with manager at URL (skips mTLS startup, runs enrollment flow only, then exits)"
)]
enroll: Option<String>,
/// Validate existing certs and re-enroll if expiring within threshold or invalid
#[arg(
long,
help = "Validate existing certs and re-enroll if expiring within threshold or invalid, then exits"
)]
renew_certs: bool,
}
/// Exit codes for the daemon
enum ExitCode {
/// Clean exit: no certs + no enrollment URL, or --enroll/--renew-certs success
Clean = 0,
/// Error: config error, enrollment network failure, cert validation error
Error = 1,
/// Certs invalid, auto-enrollment in progress (triggers systemd restart with backoff)
EnrollmentInProgress = 2,
}
#[actix_web::main]
@ -69,8 +92,9 @@ async fn main() -> Result<()> {
"Linux Patch API starting"
);
// Load configuration
let config = match AppConfig::load(&args.config, args.enroll.is_some()) {
// Load configuration (skip TLS validation during enrollment mode)
let skip_tls_validation = args.enroll.is_some();
let mut config = match AppConfig::load(&args.config, skip_tls_validation) {
Ok(cfg) => {
info!(
port = cfg.server.port,
@ -81,23 +105,142 @@ async fn main() -> Result<()> {
}
Err(e) => {
error!(error = %e, path = args.config, "Failed to load configuration");
return Err(anyhow::anyhow!("Configuration error: {}", e));
std::process::exit(ExitCode::Error as i32);
}
};
// Handle enrollment mode - runs before server startup
// Handle --renew-certs flag: validate certs and re-enroll if needed
if args.renew_certs {
info!("Certificate renewal mode activated - validating existing certificates");
match validate_certs(&config) {
Ok(CertStatus::Valid) => {
info!("Certificates are valid and not expiring soon. No renewal needed.");
std::process::exit(ExitCode::Clean as i32);
}
Ok(CertStatus::ExpiringSoon { not_after }) => {
info!(
not_after = %not_after,
"Certificates expiring soon - starting re-enrollment"
);
}
Ok(status) => {
info!(
status = %status,
"Certificates are {} - starting re-enrollment",
status
);
}
Err(e) => {
error!(error = %e, "Certificate validation failed");
std::process::exit(ExitCode::Error as i32);
}
}
// Need enrollment URL to re-enroll
let manager_url = match config.enrollment_manager_url() {
Some(url) => url.to_string(),
None => {
error!(
"Cannot re-enroll: enrollment.manager_url not configured. \
Add the manager URL to config.yaml or use --enroll <url>"
);
std::process::exit(ExitCode::Error as i32);
}
};
match enroll::run_enrollment(&manager_url, &mut config, &args.config).await {
Ok(()) => {
info!("Certificate renewal complete. Start service: systemctl start linux-patch-api");
std::process::exit(ExitCode::Clean as i32);
}
Err(e) => {
error!(error = %e, "Certificate renewal failed");
std::process::exit(ExitCode::Error as i32);
}
}
}
// Handle --enroll flag: run enrollment flow then EXIT
if let Some(ref manager_url) = args.enroll {
info!(
manager_url = manager_url,
"Enrollment mode activated - running enrollment flow before server startup"
"Enrollment mode activated - running enrollment flow"
);
match enroll::run_enrollment(manager_url, &config).await {
match enroll::run_enrollment(manager_url, &mut config, &args.config).await {
Ok(()) => {
info!("Enrollment complete - proceeding to server startup");
info!("Enrollment complete. Start service: systemctl start linux-patch-api");
std::process::exit(ExitCode::Clean as i32);
}
Err(e) => {
error!(error = %e, "Enrollment failed - shutting down");
return Err(anyhow::anyhow!("Enrollment failed: {}", e));
error!(error = %e, "Enrollment failed");
std::process::exit(ExitCode::Error as i32);
}
}
}
// Auto-enrollment on startup: validate certs before starting server
if config.tls_config().is_some() {
match validate_certs(&config) {
Ok(CertStatus::Valid) => {
info!("TLS certificates validated successfully");
}
Ok(CertStatus::ExpiringSoon { not_after }) => {
warn!(
not_after = %not_after,
"Certificates expiring soon - starting normally, consider re-enrollment"
);
// TODO: Schedule background re-enrollment in future phase
}
Ok(status @ CertStatus::Missing { .. })
| Ok(status @ CertStatus::Corrupt { .. })
| Ok(status @ CertStatus::Expired { .. })
| Ok(status @ CertStatus::KeyMismatch)
| Ok(status @ CertStatus::Untrusted) => {
// Certs are invalid - check if we can auto-enroll
// Clone the manager URL before mutable borrow of config
let manager_url_opt = config.enrollment_manager_url().map(|s| s.to_string());
match manager_url_opt {
Some(manager_url) => {
info!(
status = %status,
manager_url = manager_url,
"Certs {}. Auto-enrolling with {}",
status,
manager_url
);
match enroll::run_enrollment(&manager_url, &mut config, &args.config).await {
Ok(()) => {
info!("Auto-enrollment complete - continuing to server startup");
// Re-load config to pick up any changes from enrollment
config = AppConfig::load(&args.config, false)?;
}
Err(e) => {
error!(
error = %e,
"Auto-enrollment failed - will retry on next restart"
);
std::process::exit(ExitCode::EnrollmentInProgress as i32);
}
}
}
None => {
// No enrollment URL configured - exit cleanly to avoid crash loop
error!(
status = %status,
"Certs {}. No enrollment URL configured. \
To fix this, either:\n\
1. Add enrollment.manager_url to config.yaml and restart\n\
2. Run: linux-patch-api --enroll <manager_url>\n\
3. Place certificates manually in the configured paths",
status
);
std::process::exit(ExitCode::Clean as i32);
}
}
}
Err(e) => {
error!(error = %e, "Certificate validation error");
std::process::exit(ExitCode::Error as i32);
}
}
}
@ -153,9 +296,7 @@ async fn main() -> Result<()> {
// Configure bind address
let bind_address = format!("{}:{}", config.server.bind, config.server.port);
info!(bind = %bind_address, "Starting HTTP server");
// Create server
// Create server builder
let server_builder = HttpServer::new(move || {
let mut app = App::new()
@ -175,7 +316,6 @@ async fn main() -> Result<()> {
});
// Configure health route (outside API scope)
// cache_state and backend are available via app_data registered above
app = app.configure(configure_health_route);
app
@ -194,7 +334,6 @@ async fn main() -> Result<()> {
);
info!("Linux Patch API initialized successfully");
info!("Listening on {}", bind_address);
// Apply TLS/mTLS configuration if enabled
if let Some(tls_config) = config.tls_config() {
@ -222,11 +361,37 @@ async fn main() -> Result<()> {
info!("mTLS middleware and rustls config initialized successfully");
// Create TCP listener (std::net for listen_rustls_0_23)
let tcp_listener = TcpListener::bind(&bind_address)
.map_err(|e| anyhow::anyhow!("Failed to bind to {}: {}", bind_address, e))?;
// Create TCP listener with SO_REUSEADDR using socket2
// This prevents "Address already in use" errors when restarting after a crash
let socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)
.map_err(|e| anyhow::anyhow!("Failed to create socket: {}", e))?;
info!("TCP listener bound to {}", bind_address);
socket
.set_reuse_address(true)
.map_err(|e| anyhow::anyhow!("Failed to set SO_REUSEADDR: {}", e))?;
let bind_addr: std::net::SocketAddr = bind_address
.parse()
.map_err(|e| anyhow::anyhow!("Invalid bind address '{}': {}", bind_address, e))?;
socket
.bind(&socket2::SockAddr::from(bind_addr))
.map_err(|e| {
anyhow::anyhow!("Failed to bind socket to {}: {}", bind_address, e)
})?;
socket
.listen(128)
.map_err(|e| anyhow::anyhow!("Failed to listen on socket: {}", e))?;
let tcp_listener: std::net::TcpListener = socket.into();
// Log listening AFTER successful bind
info!("Listening on {} (mTLS enabled)", bind_address);
// Clone the ServerConfig from Arc for listen_rustls_0_23
let server_config = (*rustls_config).clone();
@ -245,8 +410,37 @@ async fn main() -> Result<()> {
}
}
} else {
// Create TCP listener with SO_REUSEADDR for non-TLS mode
let socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)
.map_err(|e| anyhow::anyhow!("Failed to create socket: {}", e))?;
socket
.set_reuse_address(true)
.map_err(|e| anyhow::anyhow!("Failed to set SO_REUSEADDR: {}", e))?;
let bind_addr: std::net::SocketAddr = bind_address
.parse()
.map_err(|e| anyhow::anyhow!("Invalid bind address '{}': {}", bind_address, e))?;
socket
.bind(&socket2::SockAddr::from(bind_addr))
.map_err(|e| anyhow::anyhow!("Failed to bind socket to {}: {}", bind_address, e))?;
socket
.listen(128)
.map_err(|e| anyhow::anyhow!("Failed to listen on socket: {}", e))?;
let tcp_listener: std::net::TcpListener = socket.into();
// Log listening AFTER successful bind
info!("Listening on {} (no TLS)", bind_address);
warn!("TLS is disabled - running without mTLS authentication (INSECURE)");
server_builder.bind(&bind_address)?.run().await?;
server_builder.listen(tcp_listener)?.run().await?;
}
info!("Linux Patch API shutting down");