Private
Public Access
1
0

fix: apply cargo fmt to resolve CI formatting failures

Format all enrollment module source files and tests per rustfmt standards.
Resolves Gitea CI workflow cargo fmt check failures.
This commit is contained in:
2026-05-17 05:49:26 +00:00
parent 75ec2b8e3c
commit 5c670cbd0c
9 changed files with 491 additions and 292 deletions

View File

@ -94,23 +94,32 @@ impl WhitelistManager {
// Parse to validate - must be IPv4 or CIDR, no hostnames in auto-append // Parse to validate - must be IPv4 or CIDR, no hostnames in auto-append
let parsed_entry = if let Some((ip_str, prefix_str)) = entry_str.split_once('/') { let parsed_entry = if let Some((ip_str, prefix_str)) = entry_str.split_once('/') {
let ip: Ipv4Addr = ip_str.parse() let ip: Ipv4Addr = ip_str
.parse()
.with_context(|| format!("Invalid IP in CIDR notation: {}", entry_str))?; .with_context(|| format!("Invalid IP in CIDR notation: {}", entry_str))?;
let prefix: u8 = prefix_str.parse() let prefix: u8 = prefix_str
.parse()
.with_context(|| format!("Invalid prefix in CIDR notation: {}", entry_str))?; .with_context(|| format!("Invalid prefix in CIDR notation: {}", entry_str))?;
if prefix > 32 { if prefix > 32 {
anyhow::bail!("Invalid CIDR prefix (must be 0-32): {}", entry_str); anyhow::bail!("Invalid CIDR prefix (must be 0-32): {}", entry_str);
} }
WhitelistEntry::Cidr { network: ip, prefix } WhitelistEntry::Cidr {
network: ip,
prefix,
}
} else { } else {
let ip: Ipv4Addr = entry_str.parse() let ip: Ipv4Addr = entry_str
.parse()
.with_context(|| format!("Invalid IPv4 address: {}", entry_str))?; .with_context(|| format!("Invalid IPv4 address: {}", entry_str))?;
WhitelistEntry::Ip(ip) WhitelistEntry::Ip(ip)
}; };
// 2. Check for duplicate in current in-memory state // 2. Check for duplicate in current in-memory state
{ {
let entries = self.entries.read().map_err(|e| anyhow::anyhow!("Failed to acquire whitelist read lock: {}", e))?; let entries = self
.entries
.read()
.map_err(|e| anyhow::anyhow!("Failed to acquire whitelist read lock: {}", e))?;
for existing in entries.iter() { for existing in entries.iter() {
if *existing == parsed_entry { if *existing == parsed_entry {
info!( info!(
@ -133,11 +142,16 @@ impl WhitelistManager {
.open(&lock_path) .open(&lock_path)
.with_context(|| format!("Failed to create lock file: {}", lock_path))?; .with_context(|| format!("Failed to create lock file: {}", lock_path))?;
lock_file.lock_exclusive().context("Failed to acquire exclusive whitelist lock")?; lock_file
.lock_exclusive()
.context("Failed to acquire exclusive whitelist lock")?;
// Double-check for duplicates after acquiring lock (concurrent append scenario) // Double-check for duplicates after acquiring lock (concurrent append scenario)
{ {
let entries = self.entries.read().map_err(|e| anyhow::anyhow!("Failed to acquire whitelist read lock: {}", e))?; let entries = self
.entries
.read()
.map_err(|e| anyhow::anyhow!("Failed to acquire whitelist read lock: {}", e))?;
for existing in entries.iter() { for existing in entries.iter() {
if *existing == parsed_entry { if *existing == parsed_entry {
info!( info!(
@ -154,9 +168,12 @@ impl WhitelistManager {
// 4. Read current whitelist YAML or create empty config // 4. Read current whitelist YAML or create empty config
let mut config = if Path::new(&self.config_path).exists() { let mut config = if Path::new(&self.config_path).exists() {
self.load_config().context("Failed to load existing whitelist for append")? self.load_config()
.context("Failed to load existing whitelist for append")?
} else { } else {
WhitelistConfig { entries: Vec::new() } WhitelistConfig {
entries: Vec::new(),
}
}; };
// 5. Append new entry to allowed_ips list // 5. Append new entry to allowed_ips list
@ -168,8 +185,9 @@ impl WhitelistManager {
// Ensure parent directory exists // Ensure parent directory exists
if let Some(parent) = config_path.parent() { if let Some(parent) = config_path.parent() {
if !parent.exists() { if !parent.exists() {
fs::create_dir_all(parent) fs::create_dir_all(parent).with_context(|| {
.with_context(|| format!("Failed to create whitelist directory: {}", parent.display()))?; format!("Failed to create whitelist directory: {}", parent.display())
})?;
} }
} }
@ -182,28 +200,35 @@ impl WhitelistManager {
.create_new(true) .create_new(true)
.truncate(true) .truncate(true)
.open(&temp_path) .open(&temp_path)
.with_context(|| format!("Failed to create temp whitelist file: {}", temp_path.display()))?;
file.write_all(yaml_content.as_bytes())
.with_context(|| format!("Failed to write whitelist data to: {}", temp_path.display()))?;
file.flush()
.with_context(|| format!("Failed to flush whitelist data to: {}", temp_path.display()))?;
// Atomic rename
fs::rename(&temp_path, config_path)
.with_context(|| { .with_context(|| {
format!( format!(
"Failed to atomically rename whitelist temp file {} to {}", "Failed to create temp whitelist file: {}",
temp_path.display(), temp_path.display()
config_path.display()
) )
})?; })?;
file.write_all(yaml_content.as_bytes()).with_context(|| {
format!("Failed to write whitelist data to: {}", temp_path.display())
})?;
file.flush().with_context(|| {
format!("Failed to flush whitelist data to: {}", temp_path.display())
})?;
// Atomic rename
fs::rename(&temp_path, config_path).with_context(|| {
format!(
"Failed to atomically rename whitelist temp file {} to {}",
temp_path.display(),
config_path.display()
)
})?;
// Release lock explicitly before reload (drop happens at end of scope) // Release lock explicitly before reload (drop happens at end of scope)
drop(lock_file); drop(lock_file);
// 7. Reload in-memory state // 7. Reload in-memory state
self.reload().context("Failed to reload whitelist after append")?; self.reload()
.context("Failed to reload whitelist after append")?;
// 8. Log audit event // 8. Log audit event
tracing::info!( tracing::info!(

View File

@ -7,7 +7,7 @@
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::signal::unix::{SignalKind, signal as unix_signal}; use tokio::signal::unix::{signal as unix_signal, SignalKind};
use crate::enroll::identity; use crate::enroll::identity;
@ -99,7 +99,7 @@ impl EnrollmentClient {
.expect("Failed to parse manager URL"); .expect("Failed to parse manager URL");
match parsed.scheme() { match parsed.scheme() {
"http" | "https" => {}, // Allowed schemes "http" | "https" => {} // Allowed schemes
other => panic!( other => panic!(
"Invalid manager URL scheme '{}' — only 'http' and 'https' are allowed. \ "Invalid manager URL scheme '{}' — only 'http' and 'https' are allowed. \
Refused dangerous scheme to prevent SSRF/path traversal.", Refused dangerous scheme to prevent SSRF/path traversal.",
@ -139,12 +139,11 @@ impl EnrollmentClient {
/// - `Err` if URL parsing fails or DNS resolution yields no results /// - `Err` if URL parsing fails or DNS resolution yields no results
pub async fn manager_ip(&self) -> Result<String> { pub async fn manager_ip(&self) -> Result<String> {
// Parse URL to extract host using url crate for RFC-compliant parsing // Parse URL to extract host using url crate for RFC-compliant parsing
let parsed = url::Url::parse(&self.manager_url).with_context(|| { let parsed = url::Url::parse(&self.manager_url)
format!("Failed to parse manager URL '{}'", self.manager_url) .with_context(|| format!("Failed to parse manager URL '{}'", self.manager_url))?;
})?; let host_str = parsed
let host_str = parsed.host_str().with_context(|| { .host_str()
format!("Manager URL '{}' has no host component", self.manager_url) .with_context(|| format!("Manager URL '{}' has no host component", self.manager_url))?;
})?;
// Check if already an IP address using url::Host parsing // Check if already an IP address using url::Host parsing
if let Ok(url::Host::Ipv4(addr)) = url::Host::parse(host_str) { if let Ok(url::Host::Ipv4(addr)) = url::Host::parse(host_str) {
@ -287,9 +286,8 @@ impl EnrollmentClient {
.await .await
.context("Failed to read status response body")?; .context("Failed to read status response body")?;
let status: EnrollmentStatusResponse = let status: EnrollmentStatusResponse = serde_json::from_str(&body)
serde_json::from_str(&body) .context("Invalid status response — malformed JSON from manager")?;
.context("Invalid status response — malformed JSON from manager")?;
Ok(status) Ok(status)
} }
@ -336,7 +334,11 @@ impl EnrollmentClient {
max_attempts: u32, max_attempts: u32,
) -> Result<PkiBundle> { ) -> Result<PkiBundle> {
// Enforce hard limits // Enforce hard limits
let effective_interval = if interval_seconds == 0 { 60 } else { interval_seconds }; let effective_interval = if interval_seconds == 0 {
60
} else {
interval_seconds
};
let effective_max = match max_attempts { let effective_max = match max_attempts {
0 => 1440, 0 => 1440,
n if n > 1440 => 1440, n if n > 1440 => 1440,
@ -417,7 +419,11 @@ impl EnrollmentClient {
attempts = attempt, attempts = attempt,
"Enrollment approved — received PKI bundle from manager" "Enrollment approved — received PKI bundle from manager"
); );
return Ok(PkiBundle { ca_crt, server_crt, server_key }); return Ok(PkiBundle {
ca_crt,
server_crt,
server_key,
});
} }
EnrollmentStatusResponse::Denied => { EnrollmentStatusResponse::Denied => {
tracing::warn!( tracing::warn!(
@ -444,20 +450,22 @@ impl EnrollmentClient {
total_seconds = total_seconds, total_seconds = total_seconds,
"Enrollment polling timed out after maximum attempts" "Enrollment polling timed out after maximum attempts"
); );
Err(anyhow!("Enrollment timed out after {} hours ({}/{} attempts)", Err(anyhow!(
total_seconds / 3600, effective_max, effective_max)) "Enrollment timed out after {} hours ({}/{} attempts)",
total_seconds / 3600,
effective_max,
effective_max
))
} }
/// Create a SIGINT (Ctrl+C) signal receiver. /// Create a SIGINT (Ctrl+C) signal receiver.
fn setup_sigint() -> Result<tokio::signal::unix::Signal> { fn setup_sigint() -> Result<tokio::signal::unix::Signal> {
unix_signal(SignalKind::interrupt()) unix_signal(SignalKind::interrupt()).context("Failed to create SIGINT signal handler")
.context("Failed to create SIGINT signal handler")
} }
/// Create a SIGTERM signal receiver. /// Create a SIGTERM signal receiver.
fn setup_sigterm() -> Result<tokio::signal::unix::Signal> { fn setup_sigterm() -> Result<tokio::signal::unix::Signal> {
unix_signal(SignalKind::terminate()) unix_signal(SignalKind::terminate()).context("Failed to create SIGTERM signal handler")
.context("Failed to create SIGTERM signal handler")
} }
} }

View File

@ -66,8 +66,7 @@ pub fn get_fqdn() -> Result<String> {
/// Collect all non-loopback IPv4 addresses from network interfaces. /// Collect all non-loopback IPv4 addresses from network interfaces.
pub fn get_ip_addresses() -> Result<Vec<String>> { pub fn get_ip_addresses() -> Result<Vec<String>> {
let ifaces = if_addrs::get_if_addrs() let ifaces = if_addrs::get_if_addrs().context("Failed to enumerate network interfaces")?;
.context("Failed to enumerate network interfaces")?;
let mut addrs: Vec<String> = ifaces let mut addrs: Vec<String> = ifaces
.iter() .iter()
@ -105,16 +104,28 @@ pub fn get_os_details() -> Result<serde_json::Value> {
let unquoted = value.trim().trim_matches('"').trim_matches('\''); let unquoted = value.trim().trim_matches('"').trim_matches('\'');
match key { match key {
"NAME" => { "NAME" => {
details.insert("distro".into(), serde_json::Value::String(unquoted.to_string())); details.insert(
"distro".into(),
serde_json::Value::String(unquoted.to_string()),
);
} }
"VERSION_ID" => { "VERSION_ID" => {
details.insert("version".into(), serde_json::Value::String(unquoted.to_string())); details.insert(
"version".into(),
serde_json::Value::String(unquoted.to_string()),
);
} }
"ID_LIKE" => { "ID_LIKE" => {
details.insert("id_like".into(), serde_json::Value::String(unquoted.to_string())); details.insert(
"id_like".into(),
serde_json::Value::String(unquoted.to_string()),
);
} }
"VERSION_CODENAME" => { "VERSION_CODENAME" => {
details.insert("codename".into(), serde_json::Value::String(unquoted.to_string())); details.insert(
"codename".into(),
serde_json::Value::String(unquoted.to_string()),
);
} }
_ => {} _ => {}
} }
@ -123,7 +134,10 @@ pub fn get_os_details() -> Result<serde_json::Value> {
} else { } else {
// Fallback for systems without os-release (very rare) // Fallback for systems without os-release (very rare)
details.insert("distro".into(), serde_json::Value::String("unknown".into())); details.insert("distro".into(), serde_json::Value::String("unknown".into()));
details.insert("version".into(), serde_json::Value::String("unknown".into())); details.insert(
"version".into(),
serde_json::Value::String("unknown".into()),
);
} }
// Kernel version via uname -r // Kernel version via uname -r
@ -159,6 +173,9 @@ mod tests {
#[test] #[test]
fn os_details_contains_kernel() { fn os_details_contains_kernel() {
let details = get_os_details().expect("Failed to get OS details"); let details = get_os_details().expect("Failed to get OS details");
assert!(details.get("kernel").is_some(), "OS details must contain kernel version"); assert!(
details.get("kernel").is_some(),
"OS details must contain kernel version"
);
} }
} }

View File

@ -12,8 +12,7 @@ use anyhow::{Context, Result};
/// Re-export key types for ergonomic access from parent modules. /// Re-export key types for ergonomic access from parent modules.
pub use client::{ pub use client::{
EnrollmentClient, EnrollmentRequest, EnrollmentResponse, EnrollmentClient, EnrollmentRequest, EnrollmentResponse, EnrollmentStatusResponse, PkiBundle,
EnrollmentStatusResponse, PkiBundle,
}; };
/// Re-export identity extraction functions. /// Re-export identity extraction functions.
pub use identity::{get_fqdn, get_ip_addresses, get_machine_id, get_os_details}; pub use identity::{get_fqdn, get_ip_addresses, get_machine_id, get_os_details};
@ -40,10 +39,16 @@ pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Res
tracing::info!("Registration successful - received polling token"); tracing::info!("Registration successful - received polling token");
// Get polling config (use defaults if not set) // Get polling config (use defaults if not set)
let interval = config.enrollment.as_ref() let interval = config
.map(|e| e.polling_interval_seconds).unwrap_or(60); .enrollment
let max_attempts = config.enrollment.as_ref() .as_ref()
.map(|e| e.max_poll_attempts).unwrap_or(1440); .map(|e| e.polling_interval_seconds)
.unwrap_or(60);
let max_attempts = config
.enrollment
.as_ref()
.map(|e| e.max_poll_attempts)
.unwrap_or(1440);
// Phase 2: Polling // Phase 2: Polling
tracing::info!( tracing::info!(
@ -51,7 +56,9 @@ pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Res
max_attempts = max_attempts, max_attempts = max_attempts,
"Starting enrollment - polling phase" "Starting enrollment - polling phase"
); );
let pki_bundle = client.poll_for_approval(&response.polling_token, interval, max_attempts).await?; let pki_bundle = client
.poll_for_approval(&response.polling_token, interval, max_attempts)
.await?;
// Phase 3: PKI provisioning & whitelist update // Phase 3: PKI provisioning & whitelist update
tracing::info!("Enrollment approved - starting PKI provisioning phase"); tracing::info!("Enrollment approved - starting PKI provisioning phase");
@ -62,13 +69,15 @@ pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Res
&pki_bundle.server_crt, &pki_bundle.server_crt,
&pki_bundle.server_key, &pki_bundle.server_key,
config.tls_config(), config.tls_config(),
).await?; )
.await?;
tracing::info!("PKI bundle written to disk"); tracing::info!("PKI bundle written to disk");
// Resolve manager hostname to IP and append to whitelist // Resolve manager hostname to IP and append to whitelist
let manager_ip = client.manager_ip().await.context( let manager_ip = client
"Failed to resolve manager IP - cannot update whitelist", .manager_ip()
)?; .await
.context("Failed to resolve manager IP - cannot update whitelist")?;
provision::append_manager_to_whitelist(&manager_ip, config.whitelist_path()).await?; provision::append_manager_to_whitelist(&manager_ip, config.whitelist_path()).await?;
tracing::info!(manager_ip = %manager_ip, "Manager IP appended to whitelist"); tracing::info!(manager_ip = %manager_ip, "Manager IP appended to whitelist");

View File

@ -1,8 +1,8 @@
//! PKI provisioning module for self-enrollment. //! PKI provisioning module for self-enrollment.
//! Handles certificate extraction, validation, and secure file writing. //! Handles certificate extraction, validation, and secure file writing.
use anyhow::{bail, Context, Result};
use crate::auth::WhitelistManager; use crate::auth::WhitelistManager;
use anyhow::{bail, Context, Result};
use std::fs::{self, OpenOptions}; use std::fs::{self, OpenOptions};
use std::io::Write; use std::io::Write;
use std::os::unix::fs::OpenOptionsExt; use std::os::unix::fs::OpenOptionsExt;
@ -71,8 +71,9 @@ pub fn write_pem_file(path: &str, pem_data: &str, is_key: bool) -> Result<()> {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(parent)?.permissions(); let mut perms = fs::metadata(parent)?.permissions();
perms.set_mode(0o755); perms.set_mode(0o755);
fs::set_permissions(parent, perms) fs::set_permissions(parent, perms).with_context(|| {
.with_context(|| format!("Failed to set permissions on: {}", parent.display()))?; format!("Failed to set permissions on: {}", parent.display())
})?;
} }
} }
} }
@ -107,14 +108,13 @@ pub fn write_pem_file(path: &str, pem_data: &str, is_key: bool) -> Result<()> {
.with_context(|| format!("Failed to flush PEM data to: {}", temp_path.display()))?; .with_context(|| format!("Failed to flush PEM data to: {}", temp_path.display()))?;
// Atomic rename to target path // Atomic rename to target path
fs::rename(&temp_path, path) fs::rename(&temp_path, path).with_context(|| {
.with_context(|| { format!(
format!( "Failed to atomically rename {} to {}",
"Failed to atomically rename {} to {}", temp_path.display(),
temp_path.display(), path.display()
path.display() )
) })?;
})?;
tracing::info!( tracing::info!(
path = %path.display(), path = %path.display(),
@ -138,7 +138,11 @@ pub async fn provision_pki_bundle(
) -> Result<()> { ) -> Result<()> {
// Determine target paths from config or defaults // Determine target paths from config or defaults
let (ca_path, cert_path, key_path) = if let Some(tls) = tls_config { let (ca_path, cert_path, key_path) = if let Some(tls) = tls_config {
(tls.ca_cert.clone(), tls.server_cert.clone(), tls.server_key.clone()) (
tls.ca_cert.clone(),
tls.server_cert.clone(),
tls.server_key.clone(),
)
} else { } else {
( (
DEFAULT_CA_CERT.to_string(), DEFAULT_CA_CERT.to_string(),
@ -148,10 +152,8 @@ pub async fn provision_pki_bundle(
}; };
// 1. Validate all three PEM strings before any writes // 1. Validate all three PEM strings before any writes
validate_pem(ca_crt, "CERTIFICATE") validate_pem(ca_crt, "CERTIFICATE").context("CA certificate validation failed")?;
.context("CA certificate validation failed")?; validate_pem(server_crt, "CERTIFICATE").context("Server certificate validation failed")?;
validate_pem(server_crt, "CERTIFICATE")
.context("Server certificate validation failed")?;
// Server key can be PRIVATE KEY (PKCS#8), RSA PRIVATE KEY (PKCS#1), or EC PRIVATE KEY // Server key can be PRIVATE KEY (PKCS#8), RSA PRIVATE KEY (PKCS#1), or EC PRIVATE KEY
let key_valid = validate_pem(server_key, "PRIVATE KEY").is_ok() let key_valid = validate_pem(server_key, "PRIVATE KEY").is_ok()
@ -165,14 +167,11 @@ pub async fn provision_pki_bundle(
} }
// 2. Write to configured paths (atomic writes) // 2. Write to configured paths (atomic writes)
write_pem_file(&ca_path, ca_crt, false) write_pem_file(&ca_path, ca_crt, false).context("Failed to write CA certificate")?;
.context("Failed to write CA certificate")?;
write_pem_file(&cert_path, server_crt, false) write_pem_file(&cert_path, server_crt, false).context("Failed to write server certificate")?;
.context("Failed to write server certificate")?;
write_pem_file(&key_path, server_key, true) write_pem_file(&key_path, server_key, true).context("Failed to write server key")?;
.context("Failed to write server key")?;
// 3. Log successful provisioning with structured fields // 3. Log successful provisioning with structured fields
tracing::info!( tracing::info!(
@ -198,11 +197,19 @@ pub async fn append_manager_to_whitelist(manager_ip: &str, whitelist_path: &str)
} }
// Create or load WhitelistManager and call append_entry // Create or load WhitelistManager and call append_entry
let mut manager = WhitelistManager::new(whitelist_path) let mut manager = WhitelistManager::new(whitelist_path).with_context(|| {
.with_context(|| format!("Failed to initialize whitelist manager for path: {}", whitelist_path))?; format!(
"Failed to initialize whitelist manager for path: {}",
whitelist_path
)
})?;
manager.append_entry(ip_or_cidr) manager.append_entry(ip_or_cidr).with_context(|| {
.with_context(|| format!("Failed to append manager IP '{}' to whitelist at: {}", ip_or_cidr, whitelist_path))?; format!(
"Failed to append manager IP '{}' to whitelist at: {}",
ip_or_cidr, whitelist_path
)
})?;
Ok(()) Ok(())
} }
@ -343,7 +350,8 @@ mod tests {
let dir = tempdir().expect("failed to create temp dir"); let dir = tempdir().expect("failed to create temp dir");
let target_path = dir.path().join("cert.pem"); let target_path = dir.path().join("cert.pem");
let cert1 = sample_certificate(); let cert1 = sample_certificate();
let cert2 = "-----BEGIN CERTIFICATE-----\nNEWCERTDATA\n-----END CERTIFICATE-----".to_string(); let cert2 =
"-----BEGIN CERTIFICATE-----\nNEWCERTDATA\n-----END CERTIFICATE-----".to_string();
// Write initial file // Write initial file
write_pem_file(target_path.to_str().unwrap(), &cert1, false).expect("initial write failed"); write_pem_file(target_path.to_str().unwrap(), &cert1, false).expect("initial write failed");
@ -352,7 +360,10 @@ mod tests {
write_pem_file(target_path.to_str().unwrap(), &cert2, false).expect("second write failed"); write_pem_file(target_path.to_str().unwrap(), &cert2, false).expect("second write failed");
let backup_path = format!("{}.bak", target_path.display()); let backup_path = format!("{}.bak", target_path.display());
assert!(std::path::Path::new(&backup_path).exists(), "Backup file should exist"); assert!(
std::path::Path::new(&backup_path).exists(),
"Backup file should exist"
);
// Original content in backup // Original content in backup
let backup_content = fs::read_to_string(&backup_path).expect("failed to read backup"); let backup_content = fs::read_to_string(&backup_path).expect("failed to read backup");

View File

@ -23,8 +23,8 @@ use tracing::{error, info, warn};
use linux_patch_api::api::{configure_api_routes, configure_health_route}; use linux_patch_api::api::{configure_api_routes, configure_health_route};
use linux_patch_api::auth::{mtls, MtlsMiddleware, WhitelistManager}; use linux_patch_api::auth::{mtls, MtlsMiddleware, WhitelistManager};
use linux_patch_api::packages::create_backend;
use linux_patch_api::enroll; use linux_patch_api::enroll;
use linux_patch_api::packages::create_backend;
use linux_patch_api::{init_logging, AppConfig, JobManager}; use linux_patch_api::{init_logging, AppConfig, JobManager};
/// Linux Patch API CLI arguments /// Linux Patch API CLI arguments
@ -42,7 +42,10 @@ struct Args {
verbose: bool, verbose: bool,
/// Enroll with manager at URL (skips mTLS startup, runs enrollment flow only) /// Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)
#[arg(long, help = "Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)")] #[arg(
long,
help = "Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)"
)]
enroll: Option<String>, enroll: Option<String>,
} }
@ -78,7 +81,10 @@ async fn main() -> Result<()> {
// Handle enrollment mode - runs before server startup // Handle enrollment mode - runs before server startup
if let Some(ref manager_url) = args.enroll { if let Some(ref manager_url) = args.enroll {
info!(manager_url = manager_url, "Enrollment mode activated - running enrollment flow before server startup"); info!(
manager_url = manager_url,
"Enrollment mode activated - running enrollment flow before server startup"
);
match enroll::run_enrollment(manager_url, &config).await { match enroll::run_enrollment(manager_url, &config).await {
Ok(()) => { Ok(()) => {
info!("Enrollment complete - proceeding to server startup"); info!("Enrollment complete - proceeding to server startup");

View File

@ -25,8 +25,8 @@ use std::os::unix::fs::PermissionsExt;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc; use std::sync::Arc;
use tempfile::TempDir; use tempfile::TempDir;
use wiremock::{Mock, MockServer, ResponseTemplate};
use wiremock::matchers::{method, path, path_regex}; use wiremock::matchers::{method, path, path_regex};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Test constants /// Test constants
const TEST_TOKEN: &str = "test_enrollment_token"; const TEST_TOKEN: &str = "test_enrollment_token";
@ -63,10 +63,7 @@ fn create_temp_dirs() -> (TempDir, TempDir) {
/// Initialize an empty whitelist YAML file at the given path. /// Initialize an empty whitelist YAML file at the given path.
/// Required because WhitelistManager::new() loads existing config on construction. /// Required because WhitelistManager::new() loads existing config on construction.
fn init_empty_whitelist(path: &str) { fn init_empty_whitelist(path: &str) {
std::fs::write( std::fs::write(path, "entries: []\n").expect("Failed to create initial whitelist file");
path,
"entries: []\n",
).expect("Failed to create initial whitelist file");
} }
/// Build a TLS config pointing to the temp certificate directory. /// Build a TLS config pointing to the temp certificate directory.
@ -76,7 +73,10 @@ fn build_tls_config(cert_dir: &std::path::Path) -> TlsConfig {
port: 12443, port: 12443,
ca_cert: cert_dir.join("ca.pem").to_string_lossy().to_string(), ca_cert: cert_dir.join("ca.pem").to_string_lossy().to_string(),
server_cert: cert_dir.join("server.pem").to_string_lossy().to_string(), server_cert: cert_dir.join("server.pem").to_string_lossy().to_string(),
server_key: cert_dir.join("server.key.pem").to_string_lossy().to_string(), server_key: cert_dir
.join("server.key.pem")
.to_string_lossy()
.to_string(),
min_tls_version: "1.3".to_string(), min_tls_version: "1.3".to_string(),
} }
} }
@ -104,7 +104,11 @@ async fn test_full_enrollment_flow_happy_path() {
let ca_cert_path = cert_dir.path().join("ca.pem"); let ca_cert_path = cert_dir.path().join("ca.pem");
let server_cert_path = cert_dir.path().join("server.pem"); let server_cert_path = cert_dir.path().join("server.pem");
let server_key_path = cert_dir.path().join("server.key.pem"); let server_key_path = cert_dir.path().join("server.key.pem");
let whitelist_path = whitelist_dir.path().join("whitelist.yaml").to_string_lossy().to_string(); let whitelist_path = whitelist_dir
.path()
.join("whitelist.yaml")
.to_string_lossy()
.to_string();
init_empty_whitelist(&whitelist_path); init_empty_whitelist(&whitelist_path);
@ -128,8 +132,7 @@ async fn test_full_enrollment_flow_happy_path() {
let count = poll_count_clone.fetch_add(1, Ordering::SeqCst); let count = poll_count_clone.fetch_add(1, Ordering::SeqCst);
if count < 1 { if count < 1 {
// First poll returns pending (simulates admin review delay) // First poll returns pending (simulates admin review delay)
ResponseTemplate::new(200) ResponseTemplate::new(200).set_body_string(r#"{"status": "pending"}"#)
.set_body_string(r#"{"status": "pending"}"#)
} else { } else {
// Second poll returns approved with full PKI bundle // Second poll returns approved with full PKI bundle
ResponseTemplate::new(200).set_body_string(&format!( ResponseTemplate::new(200).set_body_string(&format!(
@ -152,7 +155,10 @@ async fn test_full_enrollment_flow_happy_path() {
let client = build_client(&base_url); let client = build_client(&base_url);
// Phase 1: Registration // Phase 1: Registration
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
assert_eq!(response.polling_token, TEST_TOKEN); assert_eq!(response.polling_token, TEST_TOKEN);
// Phase 2: Polling (should get pending first, then approved) // Phase 2: Polling (should get pending first, then approved)
@ -172,10 +178,15 @@ async fn test_full_enrollment_flow_happy_path() {
&bundle.server_crt, &bundle.server_crt,
&bundle.server_key, &bundle.server_key,
Some(&tls_config), Some(&tls_config),
).await.expect("PKI provisioning should succeed"); )
.await
.expect("PKI provisioning should succeed");
// Phase 3b: Whitelist update (manager_ip for localhost URL returns 127.0.0.1) // Phase 3b: Whitelist update (manager_ip for localhost URL returns 127.0.0.1)
let manager_ip = client.manager_ip().await.expect("Should resolve manager IP"); let manager_ip = client
.manager_ip()
.await
.expect("Should resolve manager IP");
provision::append_manager_to_whitelist(&manager_ip, &whitelist_path) provision::append_manager_to_whitelist(&manager_ip, &whitelist_path)
.await .await
.expect("Whitelist append should succeed"); .expect("Whitelist append should succeed");
@ -186,14 +197,29 @@ async fn test_full_enrollment_flow_happy_path() {
assert!(server_key_path.exists(), "Server key file should exist"); assert!(server_key_path.exists(), "Server key file should exist");
// Verify: correct permissions (key=0o600, certs=0o644) // Verify: correct permissions (key=0o600, certs=0o644)
let key_perms = std::fs::metadata(&server_key_path).unwrap().permissions().mode() & 0o777; let key_perms = std::fs::metadata(&server_key_path)
.unwrap()
.permissions()
.mode()
& 0o777;
assert_eq!(key_perms, 0o600, "Key file should have 0o600 permissions"); assert_eq!(key_perms, 0o600, "Key file should have 0o600 permissions");
let ca_perms = std::fs::metadata(&ca_cert_path).unwrap().permissions().mode() & 0o777; let ca_perms = std::fs::metadata(&ca_cert_path)
.unwrap()
.permissions()
.mode()
& 0o777;
assert_eq!(ca_perms, 0o644, "CA cert should have 0o644 permissions"); assert_eq!(ca_perms, 0o644, "CA cert should have 0o644 permissions");
let server_perms = std::fs::metadata(&server_cert_path).unwrap().permissions().mode() & 0o777; let server_perms = std::fs::metadata(&server_cert_path)
assert_eq!(server_perms, 0o644, "Server cert should have 0o644 permissions"); .unwrap()
.permissions()
.mode()
& 0o777;
assert_eq!(
server_perms, 0o644,
"Server cert should have 0o644 permissions"
);
// Verify: whitelist contains manager IP // Verify: whitelist contains manager IP
let wl_content = std::fs::read_to_string(&whitelist_path).unwrap(); let wl_content = std::fs::read_to_string(&whitelist_path).unwrap();
@ -220,14 +246,17 @@ async fn test_enrollment_denied_flow() {
let (server, base_url) = create_mock_manager().await; let (server, base_url) = create_mock_manager().await;
let (cert_dir, _whitelist_dir) = create_temp_dirs(); let (cert_dir, _whitelist_dir) = create_temp_dirs();
let whitelist_path = _whitelist_dir.path().join("whitelist.yaml").to_string_lossy().to_string(); let whitelist_path = _whitelist_dir
.path()
.join("whitelist.yaml")
.to_string_lossy()
.to_string();
init_empty_whitelist(&whitelist_path); init_empty_whitelist(&whitelist_path);
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "denied_token"}"#),
.set_body_string(r#"{"polling_token": "denied_token"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -235,9 +264,7 @@ async fn test_enrollment_denied_flow() {
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"status": "denied"}"#))
ResponseTemplate::new(200).set_body_string(r#"{"status": "denied"}"#),
)
.named("status_denied") .named("status_denied")
.expect(1) // Exactly one poll attempt before denial .expect(1) // Exactly one poll attempt before denial
.mount(&server) .mount(&server)
@ -246,7 +273,10 @@ async fn test_enrollment_denied_flow() {
let client = build_client(&base_url); let client = build_client(&base_url);
// Phase 1: Registration succeeds even for denied enrollment // Phase 1: Registration succeeds even for denied enrollment
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
assert_eq!(response.polling_token, "denied_token"); assert_eq!(response.polling_token, "denied_token");
// Phase 2: Polling returns denial error // Phase 2: Polling returns denial error
@ -254,7 +284,10 @@ async fn test_enrollment_denied_flow() {
.poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 10) .poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 10)
.await; .await;
assert!(result.is_err(), "Should receive error for denied enrollment"); assert!(
result.is_err(),
"Should receive error for denied enrollment"
);
let err_msg = result.unwrap_err().to_string(); let err_msg = result.unwrap_err().to_string();
assert!( assert!(
err_msg.contains("denied"), err_msg.contains("denied"),
@ -267,9 +300,18 @@ async fn test_enrollment_denied_flow() {
let server_cert_path = cert_dir.path().join("server.pem"); let server_cert_path = cert_dir.path().join("server.pem");
let server_key_path = cert_dir.path().join("server.key.pem"); let server_key_path = cert_dir.path().join("server.key.pem");
assert!(!ca_path.exists(), "CA cert should NOT exist after denied enrollment"); assert!(
assert!(!server_cert_path.exists(), "Server cert should NOT exist after denied enrollment"); !ca_path.exists(),
assert!(!server_key_path.exists(), "Server key should NOT exist after denied enrollment"); "CA cert should NOT exist after denied enrollment"
);
assert!(
!server_cert_path.exists(),
"Server cert should NOT exist after denied enrollment"
);
assert!(
!server_key_path.exists(),
"Server key should NOT exist after denied enrollment"
);
// Verify: no whitelist modifications on failed enrollment // Verify: no whitelist modifications on failed enrollment
let wl_content = std::fs::read_to_string(&whitelist_path).unwrap(); let wl_content = std::fs::read_to_string(&whitelist_path).unwrap();
@ -298,8 +340,7 @@ async fn test_enrollment_timeout_flow() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "timeout_token"}"#),
.set_body_string(r#"{"polling_token": "timeout_token"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -307,16 +348,17 @@ async fn test_enrollment_timeout_flow() {
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"status": "pending"}"#))
ResponseTemplate::new(200).set_body_string(r#"{"status": "pending"}"#),
)
.named("status_always_pending") .named("status_always_pending")
.expect(3) // Exactly 3 poll attempts before timeout .expect(3) // Exactly 3 poll attempts before timeout
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
// Poll with max_attempts=3 - should timeout after exactly 3 attempts // Poll with max_attempts=3 - should timeout after exactly 3 attempts
let result = client let result = client
@ -337,8 +379,14 @@ async fn test_enrollment_timeout_flow() {
let server_key_path = cert_dir.path().join("server.key.pem"); let server_key_path = cert_dir.path().join("server.key.pem");
assert!(!ca_path.exists(), "CA cert should NOT exist after timeout"); assert!(!ca_path.exists(), "CA cert should NOT exist after timeout");
assert!(!server_cert_path.exists(), "Server cert should NOT exist after timeout"); assert!(
assert!(!server_key_path.exists(), "Server key should NOT exist after timeout"); !server_cert_path.exists(),
"Server cert should NOT exist after timeout"
);
assert!(
!server_key_path.exists(),
"Server key should NOT exist after timeout"
);
} }
// ============================================================================= // =============================================================================
@ -359,32 +407,32 @@ async fn test_certificate_permission_verification() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "perm_token"}"#),
.set_body_string(r#"{"polling_token": "perm_token"}"#),
) )
.mount(&server) .mount(&server)
.await; .await;
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(&format!(
ResponseTemplate::new(200).set_body_string(&format!( r#"{{
r#"{{
"status": "approved", "status": "approved",
"ca_crt": {}, "ca_crt": {},
"server_crt": {}, "server_crt": {},
"server_key": {} "server_key": {}
}}"#, }}"#,
serde_json::to_string(DUMMY_CA_PEM).unwrap(), serde_json::to_string(DUMMY_CA_PEM).unwrap(),
serde_json::to_string(DUMMY_SERVER_PEM).unwrap(), serde_json::to_string(DUMMY_SERVER_PEM).unwrap(),
serde_json::to_string(DUMMY_KEY_PEM).unwrap(), serde_json::to_string(DUMMY_KEY_PEM).unwrap(),
)), )))
)
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
let bundle = client let bundle = client
.poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 5) .poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 5)
.await .await
@ -397,14 +445,15 @@ async fn test_certificate_permission_verification() {
&bundle.server_crt, &bundle.server_crt,
&bundle.server_key, &bundle.server_key,
Some(&tls_config), Some(&tls_config),
).await.expect("PKI provisioning should succeed"); )
.await
.expect("PKI provisioning should succeed");
// Verify key file: 0o600 (owner read/write only) // Verify key file: 0o600 (owner read/write only)
let key_path = cert_dir.path().join("server.key.pem"); let key_path = cert_dir.path().join("server.key.pem");
let key_perms = std::fs::metadata(&key_path).unwrap().permissions().mode() & 0o777; let key_perms = std::fs::metadata(&key_path).unwrap().permissions().mode() & 0o777;
assert_eq!( assert_eq!(
key_perms, key_perms, 0o600,
0o600,
"Key file must have exactly 0o600 permissions (owner rw only)" "Key file must have exactly 0o600 permissions (owner rw only)"
); );
@ -412,17 +461,19 @@ async fn test_certificate_permission_verification() {
let ca_path = cert_dir.path().join("ca.pem"); let ca_path = cert_dir.path().join("ca.pem");
let ca_perms = std::fs::metadata(&ca_path).unwrap().permissions().mode() & 0o777; let ca_perms = std::fs::metadata(&ca_path).unwrap().permissions().mode() & 0o777;
assert_eq!( assert_eq!(
ca_perms, ca_perms, 0o644,
0o644,
"CA certificate must have exactly 0o644 permissions" "CA certificate must have exactly 0o644 permissions"
); );
// Verify server cert: 0o644 (owner rw, group/others read) // Verify server cert: 0o644 (owner rw, group/others read)
let server_cert_path = cert_dir.path().join("server.pem"); let server_cert_path = cert_dir.path().join("server.pem");
let server_perms = std::fs::metadata(&server_cert_path).unwrap().permissions().mode() & 0o777; let server_perms = std::fs::metadata(&server_cert_path)
.unwrap()
.permissions()
.mode()
& 0o777;
assert_eq!( assert_eq!(
server_perms, server_perms, 0o644,
0o644,
"Server certificate must have exactly 0o644 permissions" "Server certificate must have exactly 0o644 permissions"
); );
@ -442,7 +493,9 @@ async fn test_certificate_permission_verification() {
assert!(ca_content.contains("END CERTIFICATE")); assert!(ca_content.contains("END CERTIFICATE"));
let key_content = std::fs::read_to_string(&key_path).unwrap(); let key_content = std::fs::read_to_string(&key_path).unwrap();
assert!(key_content.contains("BEGIN PRIVATE KEY") || key_content.contains("BEGIN RSA PRIVATE KEY")); assert!(
key_content.contains("BEGIN PRIVATE KEY") || key_content.contains("BEGIN RSA PRIVATE KEY")
);
} }
// ============================================================================= // =============================================================================
@ -459,45 +512,52 @@ async fn test_whitelist_append_verification() {
let (server, base_url) = create_mock_manager().await; let (server, base_url) = create_mock_manager().await;
let (_cert_dir, whitelist_dir) = create_temp_dirs(); let (_cert_dir, whitelist_dir) = create_temp_dirs();
let whitelist_path = whitelist_dir.path().join("whitelist.yaml").to_string_lossy().to_string(); let whitelist_path = whitelist_dir
.path()
.join("whitelist.yaml")
.to_string_lossy()
.to_string();
init_empty_whitelist(&whitelist_path); init_empty_whitelist(&whitelist_path);
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "wl_token"}"#),
.set_body_string(r#"{"polling_token": "wl_token"}"#),
) )
.mount(&server) .mount(&server)
.await; .await;
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(&format!(
ResponseTemplate::new(200).set_body_string(&format!( r#"{{
r#"{{
"status": "approved", "status": "approved",
"ca_crt": {}, "ca_crt": {},
"server_crt": {}, "server_crt": {},
"server_key": {} "server_key": {}
}}"#, }}"#,
serde_json::to_string(DUMMY_CA_PEM).unwrap(), serde_json::to_string(DUMMY_CA_PEM).unwrap(),
serde_json::to_string(DUMMY_SERVER_PEM).unwrap(), serde_json::to_string(DUMMY_SERVER_PEM).unwrap(),
serde_json::to_string(DUMMY_KEY_PEM).unwrap(), serde_json::to_string(DUMMY_KEY_PEM).unwrap(),
)), )))
)
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
let _bundle = client let _bundle = client
.poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 5) .poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 5)
.await .await
.expect("Should receive approved PkiBundle"); .expect("Should receive approved PkiBundle");
// First enrollment: append to whitelist // First enrollment: append to whitelist
let manager_ip = client.manager_ip().await.expect("Should resolve manager IP"); let manager_ip = client
.manager_ip()
.await
.expect("Should resolve manager IP");
provision::append_manager_to_whitelist(&manager_ip, &whitelist_path) provision::append_manager_to_whitelist(&manager_ip, &whitelist_path)
.await .await
.expect("First whitelist append should succeed"); .expect("First whitelist append should succeed");
@ -544,7 +604,10 @@ async fn test_whitelist_append_verification() {
); );
// Verify: YAML format is valid and parseable // Verify: YAML format is valid and parseable
assert!(wl_content.contains("entries:"), "YAML should contain 'entries:' key"); assert!(
wl_content.contains("entries:"),
"YAML should contain 'entries:' key"
);
} }
// ============================================================================= // =============================================================================
@ -565,24 +628,24 @@ async fn test_signal_handling_during_polling() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "signal_token"}"#),
.set_body_string(r#"{"polling_token": "signal_token"}"#),
) )
.mount(&server) .mount(&server)
.await; .await;
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"status": "pending"}"#))
ResponseTemplate::new(200).set_body_string(r#"{"status": "pending"}"#),
)
.named("always_pending") .named("always_pending")
.expect(3) // Exactly 3 polls before graceful shutdown .expect(3) // Exactly 3 polls before graceful shutdown
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
// Poll with max_attempts=3, interval=1s // Poll with max_attempts=3, interval=1s
// This simulates SIGTERM interrupt by exhausting attempts (graceful shutdown) // This simulates SIGTERM interrupt by exhausting attempts (graceful shutdown)
@ -602,8 +665,11 @@ async fn test_signal_handling_during_polling() {
// Verify: cleanup of any partial state (no leftover files) // Verify: cleanup of any partial state (no leftover files)
for entry in std::fs::read_dir(cert_dir.path()).unwrap() { for entry in std::fs::read_dir(cert_dir.path()).unwrap() {
let entry = entry.unwrap(); let entry = entry.unwrap();
assert!(false, "No partial files should remain after graceful shutdown: {}", assert!(
entry.file_name().to_string_lossy()); false,
"No partial files should remain after graceful shutdown: {}",
entry.file_name().to_string_lossy()
);
} }
} }
@ -620,45 +686,52 @@ async fn test_whitelist_yaml_format_preservation() {
let (server, base_url) = create_mock_manager().await; let (server, base_url) = create_mock_manager().await;
let (_cert_dir, whitelist_dir) = create_temp_dirs(); let (_cert_dir, whitelist_dir) = create_temp_dirs();
let whitelist_path = whitelist_dir.path().join("whitelist.yaml").to_string_lossy().to_string(); let whitelist_path = whitelist_dir
.path()
.join("whitelist.yaml")
.to_string_lossy()
.to_string();
init_empty_whitelist(&whitelist_path); init_empty_whitelist(&whitelist_path);
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "yaml_token"}"#),
.set_body_string(r#"{"polling_token": "yaml_token"}"#),
) )
.mount(&server) .mount(&server)
.await; .await;
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(&format!(
ResponseTemplate::new(200).set_body_string(&format!( r#"{{
r#"{{
"status": "approved", "status": "approved",
"ca_crt": {}, "ca_crt": {},
"server_crt": {}, "server_crt": {},
"server_key": {} "server_key": {}
}}"#, }}"#,
serde_json::to_string(DUMMY_CA_PEM).unwrap(), serde_json::to_string(DUMMY_CA_PEM).unwrap(),
serde_json::to_string(DUMMY_SERVER_PEM).unwrap(), serde_json::to_string(DUMMY_SERVER_PEM).unwrap(),
serde_json::to_string(DUMMY_KEY_PEM).unwrap(), serde_json::to_string(DUMMY_KEY_PEM).unwrap(),
)), )))
)
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
let _bundle = client let _bundle = client
.poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 5) .poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 5)
.await .await
.expect("Should receive approved PkiBundle"); .expect("Should receive approved PkiBundle");
// Provision and append to whitelist // Provision and append to whitelist
let manager_ip = client.manager_ip().await.expect("Should resolve manager IP"); let manager_ip = client
.manager_ip()
.await
.expect("Should resolve manager IP");
provision::append_manager_to_whitelist(&manager_ip, &whitelist_path) provision::append_manager_to_whitelist(&manager_ip, &whitelist_path)
.await .await
.expect("Whitelist append should succeed"); .expect("Whitelist append should succeed");
@ -667,11 +740,14 @@ async fn test_whitelist_yaml_format_preservation() {
let wl_content = std::fs::read_to_string(&whitelist_path).unwrap(); let wl_content = std::fs::read_to_string(&whitelist_path).unwrap();
// Parse as serde_yaml to verify format // Parse as serde_yaml to verify format
let wl_config: serde_yaml::Value = serde_yaml::from_str(&wl_content) let wl_config: serde_yaml::Value =
.expect("Whitelist should be valid YAML after enrollment"); serde_yaml::from_str(&wl_content).expect("Whitelist should be valid YAML after enrollment");
// Verify structure: entries key exists and is a sequence // Verify structure: entries key exists and is a sequence
assert!(wl_config.get("entries").is_some(), "YAML must contain 'entries' key"); assert!(
wl_config.get("entries").is_some(),
"YAML must contain 'entries' key"
);
let entries = wl_config.get("entries").unwrap(); let entries = wl_config.get("entries").unwrap();
assert!(entries.is_sequence(), "'entries' must be a YAML sequence"); assert!(entries.is_sequence(), "'entries' must be a YAML sequence");

View File

@ -9,13 +9,11 @@
//! - Short polling intervals ensure tests complete quickly //! - Short polling intervals ensure tests complete quickly
//! - serial_test prevents port conflicts between concurrent test runs //! - serial_test prevents port conflicts between concurrent test runs
use linux_patch_api::enroll::client::{ use linux_patch_api::enroll::client::EnrollmentClient;
EnrollmentClient,
};
use serial_test::serial; use serial_test::serial;
use wiremock::{ use wiremock::{
Mock, MockServer, ResponseTemplate,
matchers::{method, path, path_regex}, matchers::{method, path, path_regex},
Mock, MockServer, ResponseTemplate,
}; };
/// Test constants /// Test constants
@ -54,8 +52,7 @@ async fn test_successful_enrollment_flow() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "test_token_123"}"#),
.set_body_string(r#"{"polling_token": "test_token_123"}"#),
) )
.named("enroll_registration") .named("enroll_registration")
.mount(&server) .mount(&server)
@ -81,7 +78,10 @@ async fn test_successful_enrollment_flow() {
let client = build_client(&base_url); let client = build_client(&base_url);
// Phase 1: Register - should succeed with polling token // Phase 1: Register - should succeed with polling token
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
assert_eq!(response.polling_token, TEST_TOKEN); assert_eq!(response.polling_token, TEST_TOKEN);
// Phase 2: Poll for approval - should get PkiBundle immediately since mock returns approved // Phase 2: Poll for approval - should get PkiBundle immediately since mock returns approved
@ -89,11 +89,23 @@ async fn test_successful_enrollment_flow() {
.poll_for_approval(TEST_TOKEN, POLL_INTERVAL_SECONDS, 5) .poll_for_approval(TEST_TOKEN, POLL_INTERVAL_SECONDS, 5)
.await; .await;
assert!(result.is_ok(), "Polling should succeed with approved status"); assert!(
result.is_ok(),
"Polling should succeed with approved status"
);
let bundle = result.unwrap(); let bundle = result.unwrap();
assert_eq!(bundle.ca_crt, "-----BEGIN CERTIFICATE-----\nCA_CERT_DATA\n-----END CERTIFICATE-----"); assert_eq!(
assert_eq!(bundle.server_crt, "-----BEGIN CERTIFICATE-----\nSERVER_CERT_DATA\n-----END CERTIFICATE-----"); bundle.ca_crt,
assert_eq!(bundle.server_key, "-----BEGIN PRIVATE KEY-----\nSERVER_KEY_DATA\n-----END PRIVATE KEY-----"); "-----BEGIN CERTIFICATE-----\nCA_CERT_DATA\n-----END CERTIFICATE-----"
);
assert_eq!(
bundle.server_crt,
"-----BEGIN CERTIFICATE-----\nSERVER_CERT_DATA\n-----END CERTIFICATE-----"
);
assert_eq!(
bundle.server_key,
"-----BEGIN PRIVATE KEY-----\nSERVER_KEY_DATA\n-----END PRIVATE KEY-----"
);
} }
// ============================================================================= // =============================================================================
@ -111,8 +123,7 @@ async fn test_pending_then_approved_sequence() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "seq_token_456"}"#),
.set_body_string(r#"{"polling_token": "seq_token_456"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -121,16 +132,14 @@ async fn test_pending_then_approved_sequence() {
// Status always returns approved (simplifies test while verifying the happy path) // Status always returns approved (simplifies test while verifying the happy path)
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(
ResponseTemplate::new(200).set_body_string( r#"{
r#"{
"status": "approved", "status": "approved",
"ca_crt": "CA_PEM", "ca_crt": "CA_PEM",
"server_crt": "SERVER_PEM", "server_crt": "SERVER_PEM",
"server_key": "KEY_PEM" "server_key": "KEY_PEM"
}"#, }"#,
), ))
)
.named("status_approved") .named("status_approved")
.mount(&server) .mount(&server)
.await; .await;
@ -168,8 +177,7 @@ async fn test_denied_enrollment() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "denied_token_789"}"#),
.set_body_string(r#"{"polling_token": "denied_token_789"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -178,10 +186,7 @@ async fn test_denied_enrollment() {
// Status returns denied immediately // Status returns denied immediately
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/api/v1/enroll/status/denied_token_789")) .and(path("/api/v1/enroll/status/denied_token_789"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"status": "denied"}"#))
ResponseTemplate::new(200)
.set_body_string(r#"{"status": "denied"}"#),
)
.named("status_denied") .named("status_denied")
.expect(1) // Exactly one poll attempt .expect(1) // Exactly one poll attempt
.mount(&server) .mount(&server)
@ -190,7 +195,10 @@ async fn test_denied_enrollment() {
let client = build_client(&base_url); let client = build_client(&base_url);
// Register succeeds // Register succeeds
let response = client.register().await.expect("Registration should succeed even for denied enrollment"); let response = client
.register()
.await
.expect("Registration should succeed even for denied enrollment");
assert_eq!(response.polling_token, "denied_token_789"); assert_eq!(response.polling_token, "denied_token_789");
// Poll should return error // Poll should return error
@ -198,7 +206,10 @@ async fn test_denied_enrollment() {
.poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 10) .poll_for_approval(&response.polling_token, POLL_INTERVAL_SECONDS, 10)
.await; .await;
assert!(result.is_err(), "Should receive error for denied enrollment"); assert!(
result.is_err(),
"Should receive error for denied enrollment"
);
let err_msg = result.unwrap_err().to_string(); let err_msg = result.unwrap_err().to_string();
assert!( assert!(
err_msg.contains("denied"), err_msg.contains("denied"),
@ -223,8 +234,7 @@ async fn test_token_not_found_expired() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "expired_token_000"}"#),
.set_body_string(r#"{"polling_token": "expired_token_000"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -233,10 +243,7 @@ async fn test_token_not_found_expired() {
// Status returns notfound (serde rename_all="lowercase" converts NotFound -> "notfind") // Status returns notfound (serde rename_all="lowercase" converts NotFound -> "notfind")
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/api/v1/enroll/status/expired_token_000")) .and(path("/api/v1/enroll/status/expired_token_000"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"status": "notfound"}"#))
ResponseTemplate::new(200)
.set_body_string(r#"{"status": "notfound"}"#),
)
.named("status_not_found") .named("status_not_found")
.expect(1) // Exactly one poll attempt .expect(1) // Exactly one poll attempt
.mount(&server) .mount(&server)
@ -245,7 +252,10 @@ async fn test_token_not_found_expired() {
let client = build_client(&base_url); let client = build_client(&base_url);
// Register succeeds // Register succeeds
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
// Poll should return error about expired/invalid token // Poll should return error about expired/invalid token
let result = client let result = client
@ -277,8 +287,7 @@ async fn test_max_attempts_timeout() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "timeout_token_abc"}"#),
.set_body_string(r#"{"polling_token": "timeout_token_abc"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -287,10 +296,7 @@ async fn test_max_attempts_timeout() {
// Status always returns pending - should be called exactly 3 times (max_attempts=3) // Status always returns pending - should be called exactly 3 times (max_attempts=3)
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/api/v1/enroll/status/timeout_token_abc")) .and(path("/api/v1/enroll/status/timeout_token_abc"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"status": "pending"}"#))
ResponseTemplate::new(200)
.set_body_string(r#"{"status": "pending"}"#),
)
.named("status_pending_timeout") .named("status_pending_timeout")
.expect(3) // Exactly 3 poll attempts before giving up .expect(3) // Exactly 3 poll attempts before giving up
.mount(&server) .mount(&server)
@ -298,7 +304,10 @@ async fn test_max_attempts_timeout() {
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
// Poll with max_attempts=3, interval=1s // Poll with max_attempts=3, interval=1s
let result = client let result = client
@ -329,9 +338,10 @@ async fn test_rate_limit_on_registration() {
// Registration returns 429 // Registration returns 429
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with(ResponseTemplate::new(429).set_body_string( .respond_with(
r#"{"error": "Too Many Requests", "retry_after": 60}"#, ResponseTemplate::new(429)
)) .set_body_string(r#"{"error": "Too Many Requests", "retry_after": 60}"#),
)
.named("registration_rate_limited") .named("registration_rate_limited")
.expect(1) // Exactly one attempt .expect(1) // Exactly one attempt
.mount(&server) .mount(&server)
@ -382,16 +392,14 @@ async fn test_registration_payload_structure() {
// Status endpoint (for completeness) // Status endpoint (for completeness)
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path_regex(r"/api/v1/enroll/status/.+")) .and(path_regex(r"/api/v1/enroll/status/.+"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(
ResponseTemplate::new(200).set_body_string( r#"{
r#"{
"status": "approved", "status": "approved",
"ca_crt": "CA_TEST", "ca_crt": "CA_TEST",
"server_crt": "CRT_TEST", "server_crt": "CRT_TEST",
"server_key": "KEY_TEST" "server_key": "KEY_TEST"
}"#, }"#,
), ))
)
.named("status_approved") .named("status_approved")
.mount(&server) .mount(&server)
.await; .await;
@ -399,34 +407,45 @@ async fn test_registration_payload_structure() {
let client = build_client(&base_url); let client = build_client(&base_url);
// Execute registration and capture the actual request // Execute registration and capture the actual request
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
assert_eq!(response.polling_token, "payload_test_token"); assert_eq!(response.polling_token, "payload_test_token");
// Verify using server request logs // Verify using server request logs
let requests = server.received_requests().await.unwrap(); let requests = server.received_requests().await.unwrap();
let post_request = requests.iter() let post_request = requests
.iter()
.find(|r| r.method.to_string() == "POST") .find(|r| r.method.to_string() == "POST")
.expect("Should have received a POST request"); .expect("Should have received a POST request");
let body_str = std::str::from_utf8(&post_request.body).expect("Body should be valid UTF-8"); let body_str = std::str::from_utf8(&post_request.body).expect("Body should be valid UTF-8");
let payload: serde_json::Value = serde_json::from_str(body_str) let payload: serde_json::Value =
.expect("Request body should be valid JSON"); serde_json::from_str(body_str).expect("Request body should be valid JSON");
// Verify machine_id field // Verify machine_id field
let machine_id = payload.get("machine_id") let machine_id = payload
.get("machine_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.expect("machine_id field must exist and be a string"); .expect("machine_id field must exist and be a string");
assert!(!machine_id.is_empty(), "machine_id should not be empty"); assert!(!machine_id.is_empty(), "machine_id should not be empty");
assert_eq!(machine_id.len(), 32, "machine_id should be 32 characters (UUID hex)"); assert_eq!(
machine_id.len(),
32,
"machine_id should be 32 characters (UUID hex)"
);
// Verify fqdn field // Verify fqdn field
let fqdn = payload.get("fqdn") let fqdn = payload
.get("fqdn")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.expect("fqdn field must exist and be a string"); .expect("fqdn field must exist and be a string");
assert!(!fqdn.is_empty(), "fqdn should not be empty"); assert!(!fqdn.is_empty(), "fqdn should not be empty");
// Verify ip_address field // Verify ip_address field
let ip_address = payload.get("ip_address") let ip_address = payload
.get("ip_address")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.expect("ip_address field must exist and be a string"); .expect("ip_address field must exist and be a string");
assert!(!ip_address.is_empty(), "ip_address should not be empty"); assert!(!ip_address.is_empty(), "ip_address should not be empty");
@ -438,12 +457,10 @@ async fn test_registration_payload_structure() {
); );
// Verify os_details field is an object with expected keys // Verify os_details field is an object with expected keys
let os_details = payload.get("os_details") let os_details = payload
.get("os_details")
.expect("os_details field must exist"); .expect("os_details field must exist");
assert!( assert!(os_details.is_object(), "os_details should be a JSON object");
os_details.is_object(),
"os_details should be a JSON object"
);
let os_obj = os_details.as_object().unwrap(); let os_obj = os_details.as_object().unwrap();
assert!(!os_obj.is_empty(), "os_details should not be empty"); assert!(!os_obj.is_empty(), "os_details should not be empty");
@ -469,9 +486,9 @@ async fn test_server_error_on_registration() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with(ResponseTemplate::new(500).set_body_string( .respond_with(
r#"{"error": "Internal Server Error"}"#, ResponseTemplate::new(500).set_body_string(r#"{"error": "Internal Server Error"}"#),
)) )
.named("registration_server_error") .named("registration_server_error")
.expect(1) .expect(1)
.mount(&server) .mount(&server)
@ -506,8 +523,7 @@ async fn test_rate_limit_on_polling_retries() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "rl_poll_token"}"#),
.set_body_string(r#"{"polling_token": "rl_poll_token"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -516,22 +532,23 @@ async fn test_rate_limit_on_polling_retries() {
// Status returns approved on first poll // Status returns approved on first poll
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/api/v1/enroll/status/rl_poll_token")) .and(path("/api/v1/enroll/status/rl_poll_token"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(
ResponseTemplate::new(200).set_body_string( r#"{
r#"{
"status": "approved", "status": "approved",
"ca_crt": "CA_OK", "ca_crt": "CA_OK",
"server_crt": "CRT_OK", "server_crt": "CRT_OK",
"server_key": "KEY_OK" "server_key": "KEY_OK"
}"#, }"#,
), ))
)
.named("status_approved_after_retry") .named("status_approved_after_retry")
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
// Polling should succeed (mock returns approved directly) // Polling should succeed (mock returns approved directly)
let bundle = client let bundle = client
@ -579,8 +596,7 @@ async fn test_polling_default_parameters() {
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/api/v1/enroll")) .and(path("/api/v1/enroll"))
.respond_with( .respond_with(
ResponseTemplate::new(202) ResponseTemplate::new(202).set_body_string(r#"{"polling_token": "defaults_token"}"#),
.set_body_string(r#"{"polling_token": "defaults_token"}"#),
) )
.named("registration") .named("registration")
.mount(&server) .mount(&server)
@ -589,22 +605,23 @@ async fn test_polling_default_parameters() {
// Status returns approved immediately // Status returns approved immediately
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/api/v1/enroll/status/defaults_token")) .and(path("/api/v1/enroll/status/defaults_token"))
.respond_with( .respond_with(ResponseTemplate::new(200).set_body_string(
ResponseTemplate::new(200).set_body_string( r#"{
r#"{
"status": "approved", "status": "approved",
"ca_crt": "DEFAULT_CA", "ca_crt": "DEFAULT_CA",
"server_crt": "DEFAULT_CRT", "server_crt": "DEFAULT_CRT",
"server_key": "DEFAULT_KEY" "server_key": "DEFAULT_KEY"
}"#, }"#,
), ))
)
.named("status_approved") .named("status_approved")
.mount(&server) .mount(&server)
.await; .await;
let client = build_client(&base_url); let client = build_client(&base_url);
let response = client.register().await.expect("Registration should succeed"); let response = client
.register()
.await
.expect("Registration should succeed");
// Call with interval=0 (should default to 60) and max_attempts=0 (should default to 1440) // Call with interval=0 (should default to 60) and max_attempts=0 (should default to 1440)
// But since mock returns approved on first try, we don't actually wait // But since mock returns approved on first try, we don't actually wait

View File

@ -3,7 +3,9 @@
//! Comprehensive tests for cross-distribution identity extraction functions. //! Comprehensive tests for cross-distribution identity extraction functions.
//! Verifies machine-id, FQDN, IP address collection, and OS detail parsing. //! Verifies machine-id, FQDN, IP address collection, and OS detail parsing.
use linux_patch_api::enroll::identity::{get_fqdn, get_ip_addresses, get_machine_id, get_os_details}; use linux_patch_api::enroll::identity::{
get_fqdn, get_ip_addresses, get_machine_id, get_os_details,
};
use linux_patch_api::enroll::EnrollmentRequest; use linux_patch_api::enroll::EnrollmentRequest;
use serde_json::Value; use serde_json::Value;
@ -46,10 +48,7 @@ fn test_machine_id_is_consistent() {
// Multiple calls should return the same value (it's a persistent identifier) // Multiple calls should return the same value (it's a persistent identifier)
let id1 = get_machine_id().expect("Failed to get machine-id (call 1)"); let id1 = get_machine_id().expect("Failed to get machine-id (call 1)");
let id2 = get_machine_id().expect("Failed to get machine-id (call 2)"); let id2 = get_machine_id().expect("Failed to get machine-id (call 2)");
assert_eq!( assert_eq!(id1, id2, "machine-id should be consistent across calls");
id1, id2,
"machine-id should be consistent across calls"
);
} }
#[test] #[test]
@ -67,8 +66,12 @@ fn test_machine_id_fallback_file_check() {
// Verify fallback file exists (may or may not be used) // Verify fallback file exists (may or may not be used)
let fallback = std::path::Path::new("/var/lib/dbus/machine-id"); let fallback = std::path::Path::new("/var/lib/dbus/machine-id");
if fallback.exists() { if fallback.exists() {
let content = std::fs::read_to_string(fallback).expect("Failed to read fallback machine-id"); let content =
assert!(!content.trim().is_empty(), "Fallback machine-id should not be empty"); std::fs::read_to_string(fallback).expect("Failed to read fallback machine-id");
assert!(
!content.trim().is_empty(),
"Fallback machine-id should not be empty"
);
} }
// If it doesn't exist, that's fine - primary file is used instead // If it doesn't exist, that's fine - primary file is used instead
} }
@ -157,9 +160,9 @@ fn test_ip_addresses_are_valid_ipv4() {
assert_eq!(parts.len(), 4, "IP '{}' should have 4 octets", addr); assert_eq!(parts.len(), 4, "IP '{}' should have 4 octets", addr);
for part in &parts { for part in &parts {
let _octet: u8 = part let _octet: u8 = part.parse().unwrap_or_else(|_| {
.parse() panic!("IP octet '{}' in '{}' is not a valid number", part, addr)
.unwrap_or_else(|_| panic!("IP octet '{}' in '{}' is not a valid number", part, addr)); });
// u8 parse success guarantees 0-255 range // u8 parse success guarantees 0-255 range
} }
} }
@ -198,7 +201,10 @@ fn test_ip_addresses_no_broadcast() {
let addrs = get_ip_addresses().expect("Failed to get IP addresses"); let addrs = get_ip_addresses().expect("Failed to get IP addresses");
for addr in &addrs { for addr in &addrs {
assert_ne!(addr, "255.255.255.255", "Broadcast address should be excluded"); assert_ne!(
addr, "255.255.255.255",
"Broadcast address should be excluded"
);
} }
} }
@ -238,7 +244,11 @@ fn test_ip_addresses_are_unicast() {
assert!(first < 240, "Address '{}' is reserved", addr); assert!(first < 240, "Address '{}' is reserved", addr);
// Not unspecified (0.0.0.0) // Not unspecified (0.0.0.0)
assert!(!(parts == vec![0, 0, 0, 0]), "Address '{}' is unspecified", addr); assert!(
!(parts == vec![0, 0, 0, 0]),
"Address '{}' is unspecified",
addr
);
} }
} }
@ -259,7 +269,9 @@ fn test_os_details_returns_valid_json_object() {
#[test] #[test]
fn test_os_details_contains_kernel_version() { fn test_os_details_contains_kernel_version() {
let details = get_os_details().expect("Failed to get OS details"); let details = get_os_details().expect("Failed to get OS details");
let kernel = details.get("kernel").expect("OS details must contain 'kernel' field"); let kernel = details
.get("kernel")
.expect("OS details must contain 'kernel' field");
assert!(kernel.is_string(), "Kernel version should be a string"); assert!(kernel.is_string(), "Kernel version should be a string");
let kernel_str = kernel.as_str().unwrap(); let kernel_str = kernel.as_str().unwrap();
@ -297,7 +309,10 @@ fn test_os_details_distro_is_valid_string() {
assert!(distro.is_string(), "Distro should be a string"); assert!(distro.is_string(), "Distro should be a string");
let distro_str = distro.as_str().unwrap(); let distro_str = distro.as_str().unwrap();
assert!(!distro_str.is_empty(), "Distro name should not be empty"); assert!(!distro_str.is_empty(), "Distro name should not be empty");
assert_ne!(distro_str, "unknown", "Distro should be identified on this system"); assert_ne!(
distro_str, "unknown",
"Distro should be identified on this system"
);
} }
} }
@ -350,7 +365,8 @@ fn test_enrollment_payload_construction() {
let os_details = get_os_details().expect("Failed to get OS details"); let os_details = get_os_details().expect("Failed to get OS details");
// Use first non-loopback IP as the primary address // Use first non-loopback IP as the primary address
let primary_ip = ip_addrs.first() let primary_ip = ip_addrs
.first()
.expect("Should have at least one IP") .expect("Should have at least one IP")
.clone(); .clone();
@ -362,19 +378,30 @@ fn test_enrollment_payload_construction() {
}; };
// Verify payload serializes to valid JSON // Verify payload serializes to valid JSON
let json = serde_json::to_string(&request) let json =
.expect("EnrollmentRequest should serialize to valid JSON"); serde_json::to_string(&request).expect("EnrollmentRequest should serialize to valid JSON");
assert!(!json.is_empty(), "Serialized enrollment request should not be empty"); assert!(
!json.is_empty(),
"Serialized enrollment request should not be empty"
);
// Verify JSON contains all required fields // Verify JSON contains all required fields
let parsed: Value = serde_json::from_str(&json) let parsed: Value = serde_json::from_str(&json).expect("Should deserialize enrollment request");
.expect("Should deserialize enrollment request");
assert!(parsed.get("machine_id").is_some(), "JSON must contain machine_id"); assert!(
parsed.get("machine_id").is_some(),
"JSON must contain machine_id"
);
assert!(parsed.get("fqdn").is_some(), "JSON must contain fqdn"); assert!(parsed.get("fqdn").is_some(), "JSON must contain fqdn");
assert!(parsed.get("ip_address").is_some(), "JSON must contain ip_address"); assert!(
assert!(parsed.get("os_details").is_some(), "JSON must contain os_details"); parsed.get("ip_address").is_some(),
"JSON must contain ip_address"
);
assert!(
parsed.get("os_details").is_some(),
"JSON must contain os_details"
);
} }
#[test] #[test]
@ -430,8 +457,8 @@ fn test_enrollment_payload_roundtrip() {
// Serialize to JSON then deserialize back // Serialize to JSON then deserialize back
let json = serde_json::to_string(&request).expect("Failed to serialize"); let json = serde_json::to_string(&request).expect("Failed to serialize");
let deserialized: EnrollmentRequest = serde_json::from_str(&json) let deserialized: EnrollmentRequest =
.expect("Failed to deserialize enrollment request"); serde_json::from_str(&json).expect("Failed to deserialize enrollment request");
assert_eq!(request.machine_id, deserialized.machine_id); assert_eq!(request.machine_id, deserialized.machine_id);
assert_eq!(request.fqdn, deserialized.fqdn); assert_eq!(request.fqdn, deserialized.fqdn);
@ -461,7 +488,10 @@ fn test_cross_distro_os_release_parsing() {
} }
// Verify key fields are present (POSIX standard for os-release) // Verify key fields are present (POSIX standard for os-release)
assert!(parsed.contains_key("NAME"), "os-release must contain NAME field"); assert!(
parsed.contains_key("NAME"),
"os-release must contain NAME field"
);
assert!(parsed["NAME"].ne(&""), "NAME should not be empty"); assert!(parsed["NAME"].ne(&""), "NAME should not be empty");
} }