feat: add self-enrollment workflow for automated PKI provisioning
- Phase 1: CLI args (--enroll flag), enroll module skeleton, config support - Phase 2: Registration request, polling loop (24h timeout), main.rs integration - Phase 3: PKI extraction, atomic cert writing, whitelist auto-append, mTLS transition - Phase 4: E2E test suite, README/DEPLOYMENT docs, CI pipeline - Phase 5: SPEC.md, API_DOCUMENTATION.md, CHANGELOG.md, ROADMAP.md sync Security review: APPROVED (0 critical, 0 high findings) Cross-distro compatible: Debian/Ubuntu, RHEL/CentOS/Fedora, Alpine, Arch Linux
This commit is contained in:
@ -4,10 +4,13 @@
|
||||
//! Loads configuration from YAML file with auto-reload support.
|
||||
//! All connections not in whitelist are silently dropped.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use fs2::FileExt;
|
||||
use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::fs::{self, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -26,7 +29,7 @@ pub enum WhitelistEntry {
|
||||
}
|
||||
|
||||
/// Whitelist configuration loaded from YAML
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct WhitelistConfig {
|
||||
pub entries: Vec<String>,
|
||||
}
|
||||
@ -79,6 +82,141 @@ impl WhitelistManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append an IP address or CIDR entry to the whitelist file.
|
||||
/// Creates the file if it doesn't exist. Uses file locking for concurrent safety.
|
||||
/// Logs the change to audit log.
|
||||
pub fn append_entry(&mut self, ip_or_cidr: &str) -> Result<()> {
|
||||
// 1. Validate IP/CIDR format
|
||||
let entry_str = ip_or_cidr.trim();
|
||||
if entry_str.is_empty() {
|
||||
bail!("Cannot append empty whitelist entry");
|
||||
}
|
||||
|
||||
// Parse to validate - must be IPv4 or CIDR, no hostnames in auto-append
|
||||
let parsed_entry = if let Some((ip_str, prefix_str)) = entry_str.split_once('/') {
|
||||
let ip: Ipv4Addr = ip_str.parse()
|
||||
.with_context(|| format!("Invalid IP in CIDR notation: {}", entry_str))?;
|
||||
let prefix: u8 = prefix_str.parse()
|
||||
.with_context(|| format!("Invalid prefix in CIDR notation: {}", entry_str))?;
|
||||
if prefix > 32 {
|
||||
anyhow::bail!("Invalid CIDR prefix (must be 0-32): {}", entry_str);
|
||||
}
|
||||
WhitelistEntry::Cidr { network: ip, prefix }
|
||||
} else {
|
||||
let ip: Ipv4Addr = entry_str.parse()
|
||||
.with_context(|| format!("Invalid IPv4 address: {}", entry_str))?;
|
||||
WhitelistEntry::Ip(ip)
|
||||
};
|
||||
|
||||
// 2. Check for duplicate in current in-memory state
|
||||
{
|
||||
let entries = self.entries.read().map_err(|e| anyhow::anyhow!("Failed to acquire whitelist read lock: {}", e))?;
|
||||
for existing in entries.iter() {
|
||||
if *existing == parsed_entry {
|
||||
info!(
|
||||
action = "whitelist_append",
|
||||
ip = entry_str,
|
||||
source = "enrollment",
|
||||
already_exists = true,
|
||||
"Whitelist entry already exists, skipping duplicate"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Acquire exclusive file lock using fs2
|
||||
let lock_path = format!("{}.lock", self.config_path);
|
||||
let lock_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open(&lock_path)
|
||||
.with_context(|| format!("Failed to create lock file: {}", lock_path))?;
|
||||
|
||||
lock_file.lock_exclusive().context("Failed to acquire exclusive whitelist lock")?;
|
||||
|
||||
// Double-check for duplicates after acquiring lock (concurrent append scenario)
|
||||
{
|
||||
let entries = self.entries.read().map_err(|e| anyhow::anyhow!("Failed to acquire whitelist read lock: {}", e))?;
|
||||
for existing in entries.iter() {
|
||||
if *existing == parsed_entry {
|
||||
info!(
|
||||
action = "whitelist_append",
|
||||
ip = entry_str,
|
||||
source = "enrollment",
|
||||
already_exists = true,
|
||||
"Whitelist entry already exists (post-lock check), skipping duplicate"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Read current whitelist YAML or create empty config
|
||||
let mut config = if Path::new(&self.config_path).exists() {
|
||||
self.load_config().context("Failed to load existing whitelist for append")?
|
||||
} else {
|
||||
WhitelistConfig { entries: Vec::new() }
|
||||
};
|
||||
|
||||
// 5. Append new entry to allowed_ips list
|
||||
config.entries.push(entry_str.to_string());
|
||||
|
||||
// 6. Write back atomically (temp file + rename)
|
||||
let config_path = Path::new(&self.config_path);
|
||||
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = config_path.parent() {
|
||||
if !parent.exists() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("Failed to create whitelist directory: {}", parent.display()))?;
|
||||
}
|
||||
}
|
||||
|
||||
let yaml_content = serde_yaml::to_string(&config)
|
||||
.with_context(|| "Failed to serialize whitelist config")?;
|
||||
|
||||
let temp_path = config_path.with_extension("tmp");
|
||||
let mut file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.truncate(true)
|
||||
.open(&temp_path)
|
||||
.with_context(|| format!("Failed to create temp whitelist file: {}", temp_path.display()))?;
|
||||
|
||||
file.write_all(yaml_content.as_bytes())
|
||||
.with_context(|| format!("Failed to write whitelist data to: {}", temp_path.display()))?;
|
||||
file.flush()
|
||||
.with_context(|| format!("Failed to flush whitelist data to: {}", temp_path.display()))?;
|
||||
|
||||
// Atomic rename
|
||||
fs::rename(&temp_path, config_path)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to atomically rename whitelist temp file {} to {}",
|
||||
temp_path.display(),
|
||||
config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
// Release lock explicitly before reload (drop happens at end of scope)
|
||||
drop(lock_file);
|
||||
|
||||
// 7. Reload in-memory state
|
||||
self.reload().context("Failed to reload whitelist after append")?;
|
||||
|
||||
// 8. Log audit event
|
||||
tracing::info!(
|
||||
action = "whitelist_append",
|
||||
ip = entry_str,
|
||||
source = "enrollment",
|
||||
total_entries = self.entry_count(),
|
||||
"Whitelist entry added during enrollment"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if an IP address is allowed
|
||||
pub fn is_allowed(&self, ip: &Ipv4Addr) -> bool {
|
||||
let entries = self.entries.read().unwrap();
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
//! Loads and parses YAML configuration files.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Server configuration
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
@ -103,6 +103,27 @@ fn default_backend() -> String {
|
||||
"auto".to_string()
|
||||
}
|
||||
|
||||
/// Enrollment polling configuration
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct EnrollmentConfig {
|
||||
#[serde(default)]
|
||||
pub manager_url: String,
|
||||
#[serde(default)]
|
||||
pub polling_token: String,
|
||||
#[serde(default = "default_polling_interval")]
|
||||
pub polling_interval_seconds: u64,
|
||||
#[serde(default = "default_max_poll_attempts")]
|
||||
pub max_poll_attempts: u32,
|
||||
}
|
||||
|
||||
fn default_polling_interval() -> u64 {
|
||||
60
|
||||
}
|
||||
|
||||
fn default_max_poll_attempts() -> u32 {
|
||||
1440
|
||||
}
|
||||
|
||||
/// Application configuration
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct AppConfig {
|
||||
@ -115,6 +136,8 @@ pub struct AppConfig {
|
||||
pub whitelist: Option<WhitelistConfig>,
|
||||
#[serde(default)]
|
||||
pub package_manager: Option<PackageManagerConfig>,
|
||||
#[serde(default)]
|
||||
pub enrollment: Option<EnrollmentConfig>,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
@ -263,6 +286,7 @@ mod tests {
|
||||
path: "/etc/linux_patch_api/whitelist.yaml".to_string(),
|
||||
}),
|
||||
package_manager: None,
|
||||
enrollment: None,
|
||||
};
|
||||
|
||||
assert!(config.tls_config().is_some());
|
||||
|
||||
@ -6,5 +6,6 @@
|
||||
//! - Auto-reload on file change via notify watcher
|
||||
|
||||
pub mod loader;
|
||||
pub use loader::EnrollmentConfig;
|
||||
pub mod validator;
|
||||
pub mod watcher;
|
||||
|
||||
542
src/enroll/client.rs
Normal file
542
src/enroll/client.rs
Normal file
@ -0,0 +1,542 @@
|
||||
//! HTTP client wrapper for manager enrollment API communication.
|
||||
//!
|
||||
//! Provides typed request/response structures matching the manager's
|
||||
//! `/api/v1/enroll` endpoints and a reqwest-based `EnrollmentClient` with
|
||||
//! insecure TLS mode (manager approval process provides security).
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::signal::unix::{SignalKind, signal as unix_signal};
|
||||
|
||||
use crate::enroll::identity;
|
||||
|
||||
/// Payload sent to `POST /api/v1/enroll`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EnrollmentRequest {
|
||||
pub machine_id: String,
|
||||
pub fqdn: String,
|
||||
pub ip_address: String,
|
||||
pub os_details: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Response from `POST /api/v1/enroll` (HTTP 202).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EnrollmentResponse {
|
||||
pub polling_token: String,
|
||||
}
|
||||
|
||||
/// Tagged response from `GET /api/v1/enroll/status/{token}`.
|
||||
/// The manager uses a JSON-tagged enum with the `status` key.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "status", rename_all = "lowercase")]
|
||||
pub enum EnrollmentStatusResponse {
|
||||
Pending,
|
||||
Approved {
|
||||
ca_crt: String,
|
||||
server_crt: String,
|
||||
server_key: String,
|
||||
},
|
||||
Denied,
|
||||
NotFound,
|
||||
}
|
||||
|
||||
/// PEM-encoded PKI bundle extracted from an `Approved` status response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PkiBundle {
|
||||
pub ca_crt: String,
|
||||
pub server_crt: String,
|
||||
pub server_key: String,
|
||||
}
|
||||
|
||||
impl From<EnrollmentStatusResponse> for Option<PkiBundle> {
|
||||
fn from(response: EnrollmentStatusResponse) -> Self {
|
||||
match response {
|
||||
EnrollmentStatusResponse::Approved {
|
||||
ca_crt,
|
||||
server_crt,
|
||||
server_key,
|
||||
} => Some(PkiBundle {
|
||||
ca_crt,
|
||||
server_crt,
|
||||
server_key,
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP client for enrollment communication with the manager.
|
||||
///
|
||||
/// Configured with disabled TLS verification (`danger_accept_invalid_certs`)
|
||||
/// per project security model: manager approval workflow provides authorization,
|
||||
/// not initial transport encryption.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EnrollmentClient {
|
||||
/// Base URL of the manager API (e.g. `https://manager.example.com/api/v1`)
|
||||
pub manager_url: String,
|
||||
/// Pre-configured reqwest client with insecure TLS and timeout.
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl EnrollmentClient {
|
||||
/// Create a new enrollment client targeting the given manager base URL.
|
||||
///
|
||||
/// The HTTP client is configured with:
|
||||
/// - `danger_accept_invalid_certs(true)` — TLS verification disabled
|
||||
/// - 30-second timeout for request/response cycle
|
||||
///
|
||||
/// # Security
|
||||
/// Validates that `manager_url` uses an allowed scheme (`http` or `https`) and
|
||||
/// contains a valid host component. Rejects dangerous schemes like `file://`,
|
||||
/// `gopher://`, or URLs without a host.
|
||||
pub fn new(manager_url: &str) -> Self {
|
||||
// SECURITY: Validate URL scheme before building HTTP client.
|
||||
// Only http and https are permitted to prevent path traversal, SSRF,
|
||||
// or local file access via dangerous schemes (file://, gopher://, etc.).
|
||||
let parsed = url::Url::parse(manager_url)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid manager URL: {} — must be a valid URL", e))
|
||||
.expect("Failed to parse manager URL");
|
||||
|
||||
match parsed.scheme() {
|
||||
"http" | "https" => {}, // Allowed schemes
|
||||
other => panic!(
|
||||
"Invalid manager URL scheme '{}' — only 'http' and 'https' are allowed. \
|
||||
Refused dangerous scheme to prevent SSRF/path traversal.",
|
||||
other
|
||||
),
|
||||
}
|
||||
|
||||
// Ensure the URL has a host component (e.g., reject `http://` with no host)
|
||||
if parsed.host().is_none() {
|
||||
panic!(
|
||||
"Invalid manager URL — missing host component. \
|
||||
Manager URL must include a hostname or IP address (e.g., https://manager.example.com/api/v1)"
|
||||
);
|
||||
}
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to build reqwest client — static config should always succeed");
|
||||
|
||||
Self {
|
||||
manager_url: manager_url.to_string(),
|
||||
http_client,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the manager URL to an IP address.
|
||||
///
|
||||
/// Parses the `manager_url` to extract the host portion. If the host is
|
||||
/// already an IPv4/IPv6 address, returns it directly. Otherwise performs
|
||||
/// async DNS resolution via `tokio::net::lookup_host` and returns the first
|
||||
/// resolved IP.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Ok(String)` with the manager IP address (v4 or v6)
|
||||
/// - `Err` if URL parsing fails or DNS resolution yields no results
|
||||
pub async fn manager_ip(&self) -> Result<String> {
|
||||
// Parse URL to extract host using url crate for RFC-compliant parsing
|
||||
let parsed = url::Url::parse(&self.manager_url).with_context(|| {
|
||||
format!("Failed to parse manager URL '{}'", self.manager_url)
|
||||
})?;
|
||||
let host_str = parsed.host_str().with_context(|| {
|
||||
format!("Manager URL '{}' has no host component", self.manager_url)
|
||||
})?;
|
||||
|
||||
// Check if already an IP address using url::Host parsing
|
||||
if let Ok(url::Host::Ipv4(addr)) = url::Host::parse(host_str) {
|
||||
return Ok(addr.to_string());
|
||||
}
|
||||
if let Ok(url::Host::Ipv6(addr)) = url::Host::parse(host_str) {
|
||||
return Ok(addr.to_string());
|
||||
}
|
||||
|
||||
// It's a hostname — resolve via async DNS lookup
|
||||
tracing::info!(host = host_str, "Resolving manager hostname to IP address");
|
||||
let addrs: Vec<_> = tokio::net::lookup_host(format!("{}:1", host_str))
|
||||
.await
|
||||
.map(|iter| iter.collect())
|
||||
.with_context(|| format!("Failed to resolve manager hostname '{}'", host_str))?;
|
||||
|
||||
if addrs.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"DNS resolution returned no addresses for '{}'",
|
||||
host_str
|
||||
));
|
||||
}
|
||||
|
||||
// Return the first resolved IP (IPv4 typically preferred by resolver)
|
||||
let ip = addrs[0].ip();
|
||||
tracing::info!(resolved_ip = %ip, "Manager hostname resolved successfully");
|
||||
Ok(ip.to_string())
|
||||
}
|
||||
|
||||
/// Register this machine with the manager.
|
||||
///
|
||||
/// Collects host identity data (machine-id, FQDN, IP, OS details) and
|
||||
/// sends a `POST /api/v1/enroll` request to the manager.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Ok(EnrollmentResponse)` with the polling token on HTTP 202
|
||||
/// - Error on 429 (rate limited), 5xx (server error), or network failure
|
||||
pub async fn register(&self) -> Result<EnrollmentResponse> {
|
||||
// 1. Collect identity data
|
||||
let machine_id = identity::get_machine_id()
|
||||
.context("Failed to read machine-id — host cannot enroll without identity")?;
|
||||
let fqdn = identity::get_fqdn()
|
||||
.context("Failed to determine FQDN — check hostname configuration")?;
|
||||
let ip_addresses = identity::get_ip_addresses()
|
||||
.context("Failed to enumerate network interfaces — check network configuration")?;
|
||||
let os_details = identity::get_os_details()
|
||||
.context("Failed to collect OS details — /etc/os-release may be missing")?;
|
||||
|
||||
// Use first non-loopback IP (manager expects single string)
|
||||
let ip_address = ip_addresses
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "127.0.0.1".to_string());
|
||||
|
||||
// 2. Build EnrollmentRequest struct
|
||||
let request = EnrollmentRequest {
|
||||
machine_id,
|
||||
fqdn,
|
||||
ip_address,
|
||||
os_details,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
manager_url = %self.manager_url,
|
||||
"Sending enrollment registration request"
|
||||
);
|
||||
|
||||
// 3. POST to {manager_url}/api/v1/enroll
|
||||
let enroll_url = format!("{}/api/v1/enroll", self.manager_url);
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&enroll_url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Network error — failed to reach enrollment endpoint")?;
|
||||
|
||||
// 4. Handle response status codes
|
||||
match response.status().as_u16() {
|
||||
202 => {
|
||||
// Success — parse EnrollmentResponse with polling_token
|
||||
let body = response
|
||||
.text()
|
||||
.await
|
||||
.context("Failed to read enrollment response body")?;
|
||||
|
||||
let enrollment_response: EnrollmentResponse =
|
||||
serde_json::from_str(&body)
|
||||
.context("Invalid enrollment response — missing or malformed polling_token")?;
|
||||
|
||||
// SECURITY: Do not log polling_token - it is a bearer credential.
|
||||
// Log only that registration succeeded, never the token value itself.
|
||||
tracing::info!("Enrollment registration successful");
|
||||
|
||||
Ok(enrollment_response)
|
||||
}
|
||||
429 => {
|
||||
Err(anyhow!(
|
||||
"Rate limited (HTTP 429) — enrollment requests limited to 1/minute per IP. Retry after 60 seconds."
|
||||
))
|
||||
}
|
||||
status if status >= 500 => {
|
||||
let body = response.text().await.ok();
|
||||
Err(anyhow!(
|
||||
"Server error (HTTP {}) — {}. {}",
|
||||
status,
|
||||
body.as_deref().unwrap_or("no details"),
|
||||
"The manager may be experiencing issues"
|
||||
))
|
||||
}
|
||||
other => {
|
||||
let body = response.text().await.ok();
|
||||
Err(anyhow!(
|
||||
"Unexpected HTTP {} — {}",
|
||||
other,
|
||||
body.as_deref().unwrap_or("no details")
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll the enrollment status for a given token (single request).
|
||||
///
|
||||
/// Sends `GET /api/v1/enroll/status/{token}` to the manager and returns
|
||||
/// the deserialized status response.
|
||||
pub async fn poll_status(&self, token: &str) -> Result<EnrollmentStatusResponse> {
|
||||
let status_url = format!("{}/api/v1/enroll/status/{}", self.manager_url, token);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&status_url)
|
||||
.send()
|
||||
.await
|
||||
.context("Network error — failed to reach enrollment status endpoint")?;
|
||||
|
||||
match response.status().as_u16() {
|
||||
200 => {
|
||||
let body = response
|
||||
.text()
|
||||
.await
|
||||
.context("Failed to read status response body")?;
|
||||
|
||||
let status: EnrollmentStatusResponse =
|
||||
serde_json::from_str(&body)
|
||||
.context("Invalid status response — malformed JSON from manager")?;
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
404 => Err(anyhow!("Enrollment token expired or invalid (HTTP 404)")),
|
||||
429 => Err(anyhow!(
|
||||
"Rate limited (HTTP 429) — polling too frequently. Back off and retry."
|
||||
)),
|
||||
status if status >= 500 => {
|
||||
let body = response.text().await.ok();
|
||||
Err(anyhow!(
|
||||
"Server error (HTTP {}) — {}. The manager may be experiencing issues.",
|
||||
status,
|
||||
body.as_deref().unwrap_or("no details")
|
||||
))
|
||||
}
|
||||
other => {
|
||||
let body = response.text().await.ok();
|
||||
Err(anyhow!(
|
||||
"Unexpected HTTP {} — {}",
|
||||
other,
|
||||
body.as_deref().unwrap_or("no details")
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll the manager for enrollment approval status.
|
||||
///
|
||||
/// Repeatedly calls `poll_status` until the request is approved, denied,
|
||||
/// token becomes invalid, or max attempts are exhausted.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `polling_token` - Opaque token returned by `register()`
|
||||
/// * `interval_seconds` - Sleep duration between polls (0 = use 60s default)
|
||||
/// * `max_attempts` - Maximum poll attempts (0 or >1440 clamped to 1440 for 24h cap)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(PkiBundle)` when approved — contains CA cert, server cert, and server key PEMs
|
||||
/// * `Err` on denial, token expiry, timeout, or user interruption
|
||||
pub async fn poll_for_approval(
|
||||
&self,
|
||||
polling_token: &str,
|
||||
interval_seconds: u64,
|
||||
max_attempts: u32,
|
||||
) -> Result<PkiBundle> {
|
||||
// Enforce hard limits
|
||||
let effective_interval = if interval_seconds == 0 { 60 } else { interval_seconds };
|
||||
let effective_max = match max_attempts {
|
||||
0 => 1440,
|
||||
n if n > 1440 => 1440,
|
||||
n => n,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
attempts_limit = effective_max,
|
||||
interval_seconds = effective_interval,
|
||||
"Starting enrollment approval polling loop"
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let sleep_duration = Duration::from_secs(effective_interval);
|
||||
|
||||
// Set up shutdown signal listeners (all target distros are Linux/Unix)
|
||||
let mut sigint_stream = Self::setup_sigint()?;
|
||||
let mut sigterm_stream = Self::setup_sigterm()?;
|
||||
|
||||
for attempt in 1..=effective_max {
|
||||
// Elapsed tracking for log throttling
|
||||
let elapsed = start.elapsed();
|
||||
let should_log = (attempt % 10 == 0) || elapsed.as_secs() >= 300;
|
||||
|
||||
if should_log && attempt > 1 {
|
||||
tracing::info!(
|
||||
attempt = attempt,
|
||||
max_attempts = effective_max,
|
||||
elapsed_seconds = elapsed.as_secs(),
|
||||
"Enrollment approval still pending — continuing to poll"
|
||||
);
|
||||
}
|
||||
|
||||
// Race: poll request vs shutdown signal
|
||||
let status = tokio::select! {
|
||||
result = self.poll_status(polling_token) => {
|
||||
match result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
attempt = attempt,
|
||||
"Transient poll error — will retry"
|
||||
);
|
||||
// Retry on transient errors (network, 5xx)
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SIGINT handler (Ctrl+C)
|
||||
_ = sigint_stream.recv() => {
|
||||
tracing::info!("Enrollment interrupted by user (SIGINT)");
|
||||
return Err(anyhow!("Enrollment interrupted by user"));
|
||||
}
|
||||
|
||||
// SIGTERM handler
|
||||
_ = sigterm_stream.recv() => {
|
||||
tracing::info!("Enrollment interrupted by system (SIGTERM)");
|
||||
return Err(anyhow!("Enrollment interrupted by system signal"));
|
||||
}
|
||||
};
|
||||
|
||||
// Process status response
|
||||
match status {
|
||||
EnrollmentStatusResponse::Pending => {
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
continue;
|
||||
}
|
||||
EnrollmentStatusResponse::Approved {
|
||||
ca_crt,
|
||||
server_crt,
|
||||
server_key,
|
||||
} => {
|
||||
tracing::info!(
|
||||
elapsed_seconds = start.elapsed().as_secs(),
|
||||
attempts = attempt,
|
||||
"Enrollment approved — received PKI bundle from manager"
|
||||
);
|
||||
return Ok(PkiBundle { ca_crt, server_crt, server_key });
|
||||
}
|
||||
EnrollmentStatusResponse::Denied => {
|
||||
tracing::warn!(
|
||||
elapsed_seconds = start.elapsed().as_secs(),
|
||||
"Enrollment request denied by administrator"
|
||||
);
|
||||
return Err(anyhow!("Enrollment request denied by administrator"));
|
||||
}
|
||||
EnrollmentStatusResponse::NotFound => {
|
||||
tracing::warn!(
|
||||
elapsed_seconds = start.elapsed().as_secs(),
|
||||
"Enrollment token expired or invalid (not found on manager)"
|
||||
);
|
||||
return Err(anyhow!("Enrollment token expired or invalid"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exhausted all attempts
|
||||
let total_seconds = effective_max as u64 * effective_interval;
|
||||
tracing::error!(
|
||||
max_attempts = effective_max,
|
||||
interval_seconds = effective_interval,
|
||||
total_seconds = total_seconds,
|
||||
"Enrollment polling timed out after maximum attempts"
|
||||
);
|
||||
Err(anyhow!("Enrollment timed out after {} hours ({}/{} attempts)",
|
||||
total_seconds / 3600, effective_max, effective_max))
|
||||
}
|
||||
|
||||
/// Create a SIGINT (Ctrl+C) signal receiver.
|
||||
fn setup_sigint() -> Result<tokio::signal::unix::Signal> {
|
||||
unix_signal(SignalKind::interrupt())
|
||||
.context("Failed to create SIGINT signal handler")
|
||||
}
|
||||
|
||||
/// Create a SIGTERM signal receiver.
|
||||
fn setup_sigterm() -> Result<tokio::signal::unix::Signal> {
|
||||
unix_signal(SignalKind::terminate())
|
||||
.context("Failed to create SIGTERM signal handler")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn enrollment_request_serializes() {
|
||||
let request = EnrollmentRequest {
|
||||
machine_id: "test1234".into(),
|
||||
fqdn: "node.example.com".into(),
|
||||
ip_address: "192.168.1.10".into(),
|
||||
os_details: serde_json::json!({"distro": "Debian", "version": "12"}),
|
||||
};
|
||||
let json = serde_json::to_string(&request).expect("Failed to serialize EnrollmentRequest");
|
||||
assert!(json.contains("machine_id"));
|
||||
assert!(json.contains("fqdn"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enrollment_response_deserializes() {
|
||||
let json = r#"{"polling_token": "abc123def456"}"#;
|
||||
let response: EnrollmentResponse =
|
||||
serde_json::from_str(json).expect("Failed to deserialize EnrollmentResponse");
|
||||
assert_eq!(response.polling_token, "abc123def456");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_pending_deserializes() {
|
||||
let json = r#"{"status": "pending"}"#;
|
||||
let status: EnrollmentStatusResponse =
|
||||
serde_json::from_str(json).expect("Failed to deserialize Pending");
|
||||
match status {
|
||||
EnrollmentStatusResponse::Pending => {}
|
||||
_ => panic!("Expected Pending variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_approved_deserializes() {
|
||||
let json = r#"{
|
||||
"status": "approved",
|
||||
"ca_crt": "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
"server_crt": "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----",
|
||||
"server_key": "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----"
|
||||
}"#;
|
||||
let status: EnrollmentStatusResponse =
|
||||
serde_json::from_str(json).expect("Failed to deserialize Approved");
|
||||
match status {
|
||||
EnrollmentStatusResponse::Approved { .. } => {}
|
||||
_ => panic!("Expected Approved variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn approved_to_pki_bundle() {
|
||||
let status = EnrollmentStatusResponse::Approved {
|
||||
ca_crt: "ca".into(),
|
||||
server_crt: "crt".into(),
|
||||
server_key: "key".into(),
|
||||
};
|
||||
let bundle: Option<PkiBundle> = status.into();
|
||||
assert!(bundle.is_some());
|
||||
let bundle = bundle.unwrap();
|
||||
assert_eq!(bundle.ca_crt, "ca");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pending_to_pki_bundle_is_none() {
|
||||
let status = EnrollmentStatusResponse::Pending;
|
||||
let bundle: Option<PkiBundle> = status.into();
|
||||
assert!(bundle.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enrollment_client_has_insecure_tls() {
|
||||
let client = EnrollmentClient::new("https://manager.example.com/api/v1");
|
||||
// Client builds without panic — danger_accept_invalid_certs is set
|
||||
assert_eq!(client.manager_url, "https://manager.example.com/api/v1");
|
||||
}
|
||||
}
|
||||
164
src/enroll/identity.rs
Normal file
164
src/enroll/identity.rs
Normal file
@ -0,0 +1,164 @@
|
||||
//! Cross-distribution identity extraction for Linux systems.
|
||||
//!
|
||||
//! Provides machine-id, FQDN, IP address, and OS-detail collection
|
||||
//! compatible with Debian/Ubuntu, RHEL/CentOS/Fedora, Alpine, and Arch Linux.
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use std::fs;
|
||||
use std::net::IpAddr;
|
||||
use std::process::Command;
|
||||
|
||||
/// Read the D-Bus machine identifier from `/etc/machine-id`.
|
||||
/// Falls back to `/var/lib/dbus/machine-id` on older systems.
|
||||
pub fn get_machine_id() -> Result<String> {
|
||||
let primary = "/etc/machine-id";
|
||||
let fallback = "/var/lib/dbus/machine-id";
|
||||
|
||||
if let Ok(id) = fs::read_to_string(primary) {
|
||||
let trimmed = id.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
return Ok(trimmed);
|
||||
}
|
||||
}
|
||||
|
||||
let id = fs::read_to_string(fallback)
|
||||
.with_context(|| format!("Failed to read machine-id from {} or {}", primary, fallback))?;
|
||||
let trimmed = id.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
return Err(anyhow!("machine-id file is empty"));
|
||||
}
|
||||
Ok(trimmed)
|
||||
}
|
||||
|
||||
/// Resolve the fully-qualified domain name.
|
||||
/// Strategy: `gethostname` via std → fallback to `hostname` CLI → "localhost".
|
||||
pub fn get_fqdn() -> Result<String> {
|
||||
// Try reading from hostname file first (common on systemd systems)
|
||||
if let Ok(name) = fs::read_to_string("/etc/hostname") {
|
||||
let trimmed = name.trim().to_string();
|
||||
if !trimmed.is_empty() && trimmed != "(none)" {
|
||||
return Ok(trimmed);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to hostname command
|
||||
if let Ok(output) = Command::new("hostname").arg("-f").output() {
|
||||
if output.status.success() {
|
||||
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !name.is_empty() {
|
||||
return Ok(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to plain hostname
|
||||
if let Ok(output) = Command::new("hostname").output() {
|
||||
if output.status.success() {
|
||||
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !name.is_empty() {
|
||||
return Ok(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok("localhost".into())
|
||||
}
|
||||
|
||||
/// Collect all non-loopback IPv4 addresses from network interfaces.
|
||||
pub fn get_ip_addresses() -> Result<Vec<String>> {
|
||||
let ifaces = if_addrs::get_if_addrs()
|
||||
.context("Failed to enumerate network interfaces")?;
|
||||
|
||||
let mut addrs: Vec<String> = ifaces
|
||||
.iter()
|
||||
.filter_map(|iface| {
|
||||
if iface.is_loopback() {
|
||||
return None;
|
||||
}
|
||||
match &iface.ip() {
|
||||
IpAddr::V4(addr) => Some(addr.to_string()),
|
||||
IpAddr::V6(_) => None,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
addrs.sort();
|
||||
addrs.dedup();
|
||||
Ok(addrs)
|
||||
}
|
||||
|
||||
/// Extract OS distribution details from `/etc/os-release` and kernel version.
|
||||
/// Returns a JSON object with: distro, version, id_like, kernel.
|
||||
pub fn get_os_details() -> Result<serde_json::Value> {
|
||||
let mut details = serde_json::Map::new();
|
||||
|
||||
// Parse /etc/os-release (exists on all target distros)
|
||||
if let Ok(content) = fs::read_to_string("/etc/os-release") {
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
// Strip surrounding quotes from value
|
||||
let unquoted = value.trim().trim_matches('"').trim_matches('\'');
|
||||
match key {
|
||||
"NAME" => {
|
||||
details.insert("distro".into(), serde_json::Value::String(unquoted.to_string()));
|
||||
}
|
||||
"VERSION_ID" => {
|
||||
details.insert("version".into(), serde_json::Value::String(unquoted.to_string()));
|
||||
}
|
||||
"ID_LIKE" => {
|
||||
details.insert("id_like".into(), serde_json::Value::String(unquoted.to_string()));
|
||||
}
|
||||
"VERSION_CODENAME" => {
|
||||
details.insert("codename".into(), serde_json::Value::String(unquoted.to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback for systems without os-release (very rare)
|
||||
details.insert("distro".into(), serde_json::Value::String("unknown".into()));
|
||||
details.insert("version".into(), serde_json::Value::String("unknown".into()));
|
||||
}
|
||||
|
||||
// Kernel version via uname -r
|
||||
if let Ok(output) = Command::new("uname").arg("-r").output() {
|
||||
if output.status.success() {
|
||||
let kernel = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
details.insert("kernel".into(), serde_json::Value::String(kernel));
|
||||
}
|
||||
} else {
|
||||
details.insert("kernel".into(), serde_json::Value::String("unknown".into()));
|
||||
}
|
||||
|
||||
Ok(serde_json::Value::Object(details))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn machine_id_is_not_empty() {
|
||||
let id = get_machine_id().expect("Failed to get machine-id");
|
||||
assert!(!id.is_empty(), "machine-id should not be empty");
|
||||
assert_eq!(id.len(), 32, "machine-id should be 32 hex chars");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fqdn_is_not_empty() {
|
||||
let fqdn = get_fqdn().expect("Failed to get FQDN");
|
||||
assert!(!fqdn.is_empty(), "FQDN should not be empty");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn os_details_contains_kernel() {
|
||||
let details = get_os_details().expect("Failed to get OS details");
|
||||
assert!(details.get("kernel").is_some(), "OS details must contain kernel version");
|
||||
}
|
||||
}
|
||||
77
src/enroll/mod.rs
Normal file
77
src/enroll/mod.rs
Normal file
@ -0,0 +1,77 @@
|
||||
//! Self-enrollment module for linux_patch_api daemon.
|
||||
//!
|
||||
//! Handles secure registration with the patch manager, including
|
||||
//! identity extraction (machine-id, FQDN, IPs, OS details) and
|
||||
//! mTLS enrollment via the manager API.
|
||||
|
||||
pub mod client;
|
||||
pub mod identity;
|
||||
pub mod provision;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
/// Re-export key types for ergonomic access from parent modules.
|
||||
pub use client::{
|
||||
EnrollmentClient, EnrollmentRequest, EnrollmentResponse,
|
||||
EnrollmentStatusResponse, PkiBundle,
|
||||
};
|
||||
/// Re-export identity extraction functions.
|
||||
pub use identity::{get_fqdn, get_ip_addresses, get_machine_id, get_os_details};
|
||||
|
||||
/// Run the full enrollment flow against the manager at the given URL.
|
||||
///
|
||||
/// # Phases
|
||||
/// 1. **Registration** - POST machine identity to manager, receive polling token
|
||||
/// 2. **Polling** - Poll manager for approval with configurable interval/max attempts
|
||||
/// 3. **Provisioning** - Write PKI bundle to disk (certs/keys) and append manager IP to whitelist
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns Err on registration failure, polling timeout, denial, user interruption,
|
||||
/// PKI provisioning failure, or whitelist update failure.
|
||||
pub async fn run_enrollment(manager_url: &str, config: &super::AppConfig) -> Result<()> {
|
||||
let client = EnrollmentClient::new(manager_url);
|
||||
|
||||
// Phase 1: Registration
|
||||
tracing::info!(
|
||||
manager_url = manager_url,
|
||||
"Starting enrollment - registration phase"
|
||||
);
|
||||
let response = client.register().await?;
|
||||
tracing::info!("Registration successful - received polling token");
|
||||
|
||||
// Get polling config (use defaults if not set)
|
||||
let interval = config.enrollment.as_ref()
|
||||
.map(|e| e.polling_interval_seconds).unwrap_or(60);
|
||||
let max_attempts = config.enrollment.as_ref()
|
||||
.map(|e| e.max_poll_attempts).unwrap_or(1440);
|
||||
|
||||
// Phase 2: Polling
|
||||
tracing::info!(
|
||||
interval_seconds = interval,
|
||||
max_attempts = max_attempts,
|
||||
"Starting enrollment - polling phase"
|
||||
);
|
||||
let pki_bundle = client.poll_for_approval(&response.polling_token, interval, max_attempts).await?;
|
||||
|
||||
// Phase 3: PKI provisioning & whitelist update
|
||||
tracing::info!("Enrollment approved - starting PKI provisioning phase");
|
||||
|
||||
// Write certificates to configured paths (or defaults)
|
||||
provision::provision_pki_bundle(
|
||||
&pki_bundle.ca_crt,
|
||||
&pki_bundle.server_crt,
|
||||
&pki_bundle.server_key,
|
||||
config.tls_config(),
|
||||
).await?;
|
||||
tracing::info!("PKI bundle written to disk");
|
||||
|
||||
// Resolve manager hostname to IP and append to whitelist
|
||||
let manager_ip = client.manager_ip().await.context(
|
||||
"Failed to resolve manager IP - cannot update whitelist",
|
||||
)?;
|
||||
provision::append_manager_to_whitelist(&manager_ip, config.whitelist_path()).await?;
|
||||
tracing::info!(manager_ip = %manager_ip, "Manager IP appended to whitelist");
|
||||
|
||||
tracing::info!("Enrollment complete - PKI and whitelist configured");
|
||||
Ok(())
|
||||
}
|
||||
361
src/enroll/provision.rs
Normal file
361
src/enroll/provision.rs
Normal file
@ -0,0 +1,361 @@
|
||||
//! PKI provisioning module for self-enrollment.
|
||||
//! Handles certificate extraction, validation, and secure file writing.
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use crate::auth::WhitelistManager;
|
||||
use std::fs::{self, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
|
||||
/// Default certificate directory when TLS config is not provided.
|
||||
#[allow(dead_code)]
|
||||
const DEFAULT_CERT_DIR: &str = "/etc/linux_patch_api/certs";
|
||||
/// Default CA certificate path.
|
||||
const DEFAULT_CA_CERT: &str = "/etc/linux_patch_api/certs/ca.pem";
|
||||
/// Default server certificate path.
|
||||
const DEFAULT_SERVER_CERT: &str = "/etc/linux_patch_api/certs/server.pem";
|
||||
/// Default server key path.
|
||||
const DEFAULT_SERVER_KEY: &str = "/etc/linux_patch_api/certs/server.key.pem";
|
||||
|
||||
/// Validate that a PEM string has proper format (BEGIN/END markers present).
|
||||
///
|
||||
/// Checks for `-----BEGIN {expected_type}-----` and `-----END {expected_type}-----` markers.
|
||||
/// Returns an error if either marker is missing or the data is empty.
|
||||
pub fn validate_pem(pem_data: &str, expected_type: &str) -> Result<()> {
|
||||
let trimmed = pem_data.trim();
|
||||
|
||||
if trimmed.is_empty() {
|
||||
bail!("PEM data is empty for type '{}'", expected_type);
|
||||
}
|
||||
|
||||
let begin_marker = format!("-----BEGIN {}-----", expected_type);
|
||||
let end_marker = format!("-----END {}-----", expected_type);
|
||||
|
||||
if !trimmed.contains(&begin_marker) {
|
||||
bail!(
|
||||
"Invalid PEM format: missing '{}' marker for type '{}'",
|
||||
begin_marker,
|
||||
expected_type
|
||||
);
|
||||
}
|
||||
|
||||
if !trimmed.contains(&end_marker) {
|
||||
bail!(
|
||||
"Invalid PEM format: missing '{}' marker for type '{}'",
|
||||
end_marker,
|
||||
expected_type
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write PEM data to disk with secure permissions using atomic write pattern.
|
||||
///
|
||||
/// 1. Create target directory if it doesn't exist (with 0o755 permissions)
|
||||
/// 2. Backup existing file if present (.bak extension)
|
||||
/// 3. Write to temp file in same directory
|
||||
/// 4. Set correct permissions (key=0o600, certs=0o644)
|
||||
/// 5. Rename atomically to target path
|
||||
pub fn write_pem_file(path: &str, pem_data: &str, is_key: bool) -> Result<()> {
|
||||
let path = std::path::Path::new(path);
|
||||
|
||||
// Ensure target directory exists
|
||||
if let Some(parent) = path.parent() {
|
||||
if !parent.exists() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("Failed to create directory: {}", parent.display()))?;
|
||||
// Set directory permissions (0o755 for readability by service, restricted write)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let mut perms = fs::metadata(parent)?.permissions();
|
||||
perms.set_mode(0o755);
|
||||
fs::set_permissions(parent, perms)
|
||||
.with_context(|| format!("Failed to set permissions on: {}", parent.display()))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backup existing file if present
|
||||
if path.exists() {
|
||||
let backup_path = format!("{}.bak", path.display());
|
||||
fs::rename(path, &backup_path)
|
||||
.with_context(|| format!("Failed to backup existing file: {}", path.display()))?;
|
||||
tracing::info!(
|
||||
original = %path.display(),
|
||||
backup = %backup_path,
|
||||
"Backed up existing certificate file"
|
||||
);
|
||||
}
|
||||
|
||||
// Create temp file in same directory for atomic rename
|
||||
let temp_path = path.with_extension("tmp");
|
||||
|
||||
// Write PEM data to temp file
|
||||
let mut file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.truncate(true)
|
||||
.mode(if is_key { 0o600 } else { 0o644 })
|
||||
.open(&temp_path)
|
||||
.with_context(|| format!("Failed to create temp file: {}", temp_path.display()))?;
|
||||
|
||||
file.write_all(pem_data.as_bytes())
|
||||
.with_context(|| format!("Failed to write PEM data to: {}", temp_path.display()))?;
|
||||
file.flush()
|
||||
.with_context(|| format!("Failed to flush PEM data to: {}", temp_path.display()))?;
|
||||
|
||||
// Atomic rename to target path
|
||||
fs::rename(&temp_path, path)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to atomically rename {} to {}",
|
||||
temp_path.display(),
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(
|
||||
path = %path.display(),
|
||||
is_key = is_key,
|
||||
permissions = if is_key { "0600" } else { "0644" },
|
||||
"Successfully wrote PEM file"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Provision the full PKI bundle from an approved enrollment response.
|
||||
///
|
||||
/// Writes CA cert, server cert, and server key to configured paths.
|
||||
/// Paths are read from TLS config if available, otherwise defaults are used.
|
||||
pub async fn provision_pki_bundle(
|
||||
ca_crt: &str,
|
||||
server_crt: &str,
|
||||
server_key: &str,
|
||||
tls_config: Option<&super::super::config::loader::TlsConfig>,
|
||||
) -> Result<()> {
|
||||
// Determine target paths from config or defaults
|
||||
let (ca_path, cert_path, key_path) = if let Some(tls) = tls_config {
|
||||
(tls.ca_cert.clone(), tls.server_cert.clone(), tls.server_key.clone())
|
||||
} else {
|
||||
(
|
||||
DEFAULT_CA_CERT.to_string(),
|
||||
DEFAULT_SERVER_CERT.to_string(),
|
||||
DEFAULT_SERVER_KEY.to_string(),
|
||||
)
|
||||
};
|
||||
|
||||
// 1. Validate all three PEM strings before any writes
|
||||
validate_pem(ca_crt, "CERTIFICATE")
|
||||
.context("CA certificate validation failed")?;
|
||||
validate_pem(server_crt, "CERTIFICATE")
|
||||
.context("Server certificate validation failed")?;
|
||||
|
||||
// Server key can be PRIVATE KEY (PKCS#8), RSA PRIVATE KEY (PKCS#1), or EC PRIVATE KEY
|
||||
let key_valid = validate_pem(server_key, "PRIVATE KEY").is_ok()
|
||||
|| validate_pem(server_key, "RSA PRIVATE KEY").is_ok()
|
||||
|| validate_pem(server_key, "EC PRIVATE KEY").is_ok();
|
||||
|
||||
if !key_valid {
|
||||
bail!(
|
||||
"Server key validation failed: PEM must be PRIVATE KEY, RSA PRIVATE KEY, or EC PRIVATE KEY"
|
||||
);
|
||||
}
|
||||
|
||||
// 2. Write to configured paths (atomic writes)
|
||||
write_pem_file(&ca_path, ca_crt, false)
|
||||
.context("Failed to write CA certificate")?;
|
||||
|
||||
write_pem_file(&cert_path, server_crt, false)
|
||||
.context("Failed to write server certificate")?;
|
||||
|
||||
write_pem_file(&key_path, server_key, true)
|
||||
.context("Failed to write server key")?;
|
||||
|
||||
// 3. Log successful provisioning with structured fields
|
||||
tracing::info!(
|
||||
ca_cert = %ca_path,
|
||||
server_cert = %cert_path,
|
||||
server_key = %key_path,
|
||||
"PKI bundle provisioned successfully - all certificates written and validated"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append the manager IP to the whitelist after successful enrollment.
|
||||
///
|
||||
/// Creates or loads a `WhitelistManager` and calls `append_entry()` with the
|
||||
/// provided IP/CIDR string. Returns an error if the file cannot be locked,
|
||||
/// written, or reloaded.
|
||||
pub async fn append_manager_to_whitelist(manager_ip: &str, whitelist_path: &str) -> Result<()> {
|
||||
// Validate input before touching any files
|
||||
let ip_or_cidr = manager_ip.trim();
|
||||
if ip_or_cidr.is_empty() {
|
||||
bail!("Manager IP address cannot be empty");
|
||||
}
|
||||
|
||||
// Create or load WhitelistManager and call append_entry
|
||||
let mut manager = WhitelistManager::new(whitelist_path)
|
||||
.with_context(|| format!("Failed to initialize whitelist manager for path: {}", whitelist_path))?;
|
||||
|
||||
manager.append_entry(ip_or_cidr)
|
||||
.with_context(|| format!("Failed to append manager IP '{}' to whitelist at: {}", ip_or_cidr, whitelist_path))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn sample_certificate() -> String {
|
||||
"-----BEGIN CERTIFICATE-----\nMIIBxTCCAWugAwIBAgIRA ...\nBASE64ENCODED DATA HERE ...\n-----END CERTIFICATE-----".to_string()
|
||||
}
|
||||
|
||||
fn sample_rsa_key() -> String {
|
||||
"-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3...\nBASE64ENCODED DATA HERE ...\n-----END RSA PRIVATE KEY-----".to_string()
|
||||
}
|
||||
|
||||
fn sample_pkcs8_key() -> String {
|
||||
"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQE...\nBASE64ENCODED DATA HERE ...\n-----END PRIVATE KEY-----".to_string()
|
||||
}
|
||||
|
||||
fn sample_ec_key() -> String {
|
||||
"-----BEGIN EC PRIVATE KEY-----\nMHQCAQEEIBkg5Lb/...\nBASE64ENCODED DATA HERE ...\n-----END EC PRIVATE KEY-----".to_string()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_valid_certificate() {
|
||||
let cert = sample_certificate();
|
||||
assert!(validate_pem(&cert, "CERTIFICATE").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_valid_rsa_key() {
|
||||
let key = sample_rsa_key();
|
||||
assert!(validate_pem(&key, "RSA PRIVATE KEY").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_valid_pkcs8_key() {
|
||||
let key = sample_pkcs8_key();
|
||||
assert!(validate_pem(&key, "PRIVATE KEY").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_valid_ec_key() {
|
||||
let key = sample_ec_key();
|
||||
assert!(validate_pem(&key, "EC PRIVATE KEY").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_empty_data_fails() {
|
||||
assert!(validate_pem("", "CERTIFICATE").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_missing_begin_marker_fails() {
|
||||
let malformed = "BASE64DATA\n-----END CERTIFICATE-----".to_string();
|
||||
let err = validate_pem(&malformed, "CERTIFICATE").unwrap_err();
|
||||
assert!(err.to_string().contains("BEGIN"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_missing_end_marker_fails() {
|
||||
let malformed = "-----BEGIN CERTIFICATE-----\nBASE64DATA".to_string();
|
||||
let err = validate_pem(&malformed, "CERTIFICATE").unwrap_err();
|
||||
assert!(err.to_string().contains("END"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_wrong_type_fails() {
|
||||
let cert = sample_certificate();
|
||||
// Certificate data checked against wrong type should fail
|
||||
let err = validate_pem(&cert, "RSA PRIVATE KEY").unwrap_err();
|
||||
assert!(err.to_string().contains("BEGIN"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_pem_whitespace_tolerance() {
|
||||
let cert = format!("\n \n {} \n ", sample_certificate());
|
||||
assert!(validate_pem(&cert, "CERTIFICATE").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_pem_file_creates_directory() {
|
||||
let dir = tempdir().expect("failed to create temp dir");
|
||||
let target_path = dir.path().join("subdir").join("cert.pem");
|
||||
let cert = sample_certificate();
|
||||
|
||||
write_pem_file(target_path.to_str().unwrap(), &cert, false).expect("write failed");
|
||||
assert!(target_path.exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_pem_file_atomic_rename() {
|
||||
let dir = tempdir().expect("failed to create temp dir");
|
||||
let target_path = dir.path().join("cert.pem");
|
||||
let cert = sample_certificate();
|
||||
|
||||
write_pem_file(target_path.to_str().unwrap(), &cert, false).expect("write failed");
|
||||
|
||||
// Verify content matches
|
||||
let written = fs::read_to_string(&target_path).expect("failed to read back");
|
||||
assert_eq!(written, cert);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_pem_file_key_permissions() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let dir = tempdir().expect("failed to create temp dir");
|
||||
let target_path = dir.path().join("key.pem");
|
||||
let key = sample_rsa_key();
|
||||
|
||||
write_pem_file(target_path.to_str().unwrap(), &key, true).expect("write failed");
|
||||
|
||||
let metadata = fs::metadata(&target_path).expect("failed to get metadata");
|
||||
let mode = metadata.permissions().mode() & 0o777;
|
||||
assert_eq!(mode, 0o600, "Key file should have 0600 permissions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_pem_file_cert_permissions() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let dir = tempdir().expect("failed to create temp dir");
|
||||
let target_path = dir.path().join("cert.pem");
|
||||
let cert = sample_certificate();
|
||||
|
||||
write_pem_file(target_path.to_str().unwrap(), &cert, false).expect("write failed");
|
||||
|
||||
let metadata = fs::metadata(&target_path).expect("failed to get metadata");
|
||||
let mode = metadata.permissions().mode() & 0o777;
|
||||
assert_eq!(mode, 0o644, "Cert file should have 0644 permissions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_pem_file_backup_existing() {
|
||||
let dir = tempdir().expect("failed to create temp dir");
|
||||
let target_path = dir.path().join("cert.pem");
|
||||
let cert1 = sample_certificate();
|
||||
let cert2 = "-----BEGIN CERTIFICATE-----\nNEWCERTDATA\n-----END CERTIFICATE-----".to_string();
|
||||
|
||||
// Write initial file
|
||||
write_pem_file(target_path.to_str().unwrap(), &cert1, false).expect("initial write failed");
|
||||
|
||||
// Write again - should create backup
|
||||
write_pem_file(target_path.to_str().unwrap(), &cert2, false).expect("second write failed");
|
||||
|
||||
let backup_path = format!("{}.bak", target_path.display());
|
||||
assert!(std::path::Path::new(&backup_path).exists(), "Backup file should exist");
|
||||
|
||||
// Original content in backup
|
||||
let backup_content = fs::read_to_string(&backup_path).expect("failed to read backup");
|
||||
assert_eq!(backup_content, cert1);
|
||||
}
|
||||
}
|
||||
@ -15,6 +15,7 @@
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod enroll;
|
||||
pub mod jobs;
|
||||
pub mod logging;
|
||||
pub mod packages;
|
||||
|
||||
19
src/main.rs
19
src/main.rs
@ -24,6 +24,7 @@ use tracing::{error, info, warn};
|
||||
use linux_patch_api::api::{configure_api_routes, configure_health_route};
|
||||
use linux_patch_api::auth::{mtls, MtlsMiddleware, WhitelistManager};
|
||||
use linux_patch_api::packages::create_backend;
|
||||
use linux_patch_api::enroll;
|
||||
use linux_patch_api::{init_logging, AppConfig, JobManager};
|
||||
|
||||
/// Linux Patch API CLI arguments
|
||||
@ -39,6 +40,10 @@ struct Args {
|
||||
/// Enable verbose logging
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
|
||||
/// Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)
|
||||
#[arg(long, help = "Enroll with manager at URL (skips mTLS startup, runs enrollment flow only)")]
|
||||
enroll: Option<String>,
|
||||
}
|
||||
|
||||
#[actix_web::main]
|
||||
@ -71,6 +76,20 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
// Handle enrollment mode - runs before server startup
|
||||
if let Some(ref manager_url) = args.enroll {
|
||||
info!(manager_url = manager_url, "Enrollment mode activated - running enrollment flow before server startup");
|
||||
match enroll::run_enrollment(manager_url, &config).await {
|
||||
Ok(()) => {
|
||||
info!("Enrollment complete - proceeding to server startup");
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Enrollment failed - shutting down");
|
||||
return Err(anyhow::anyhow!("Enrollment failed: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize job manager
|
||||
let job_manager = JobManager::new(config.jobs.max_concurrent, config.jobs.timeout_minutes)?;
|
||||
info!(
|
||||
|
||||
Reference in New Issue
Block a user