Private
Public Access
1
0

feat: add host self-enrollment workflow v0.1.7
All checks were successful
CI Pipeline / Rust Format Check (push) Successful in 5s
CI Pipeline / Clippy Lints (push) Successful in 53s
CI Pipeline / Rust Unit Tests (push) Successful in 1m11s
CI Pipeline / Security Audit (push) Successful in 4s
CI Pipeline / Frontend Lint & Type Check (push) Successful in 14s
CI Pipeline / Build .deb & Release (push) Has been skipped

This commit is contained in:
2026-05-16 16:58:00 +00:00
parent f183c8edf8
commit da3dffd81f
17 changed files with 841 additions and 55 deletions

View File

@ -1,6 +1,8 @@
use crate::config::DatabaseConfig;
use crate::models::{CreateEnrollmentRequest, EnrollmentRequest};
use sqlx::postgres::{PgPool, PgPoolOptions};
use std::time::Duration;
use uuid::Uuid;
/// Initialize and return a PostgreSQL connection pool.
pub async fn init_pool(cfg: &DatabaseConfig) -> Result<PgPool, sqlx::Error> {
@ -56,6 +58,53 @@ pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::migrate::MigrateE
result
}
// ============================================================
// Enrollment Requests
// ============================================================
pub async fn create_enrollment_request(
pool: &PgPool,
req: CreateEnrollmentRequest,
token_hash: String,
) -> Result<EnrollmentRequest, sqlx::Error> {
sqlx::query_as::<
_,
EnrollmentRequest,
>(
r#"
INSERT INTO enrollment_requests (machine_id, fqdn, ip_address, os_details, polling_token)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, machine_id, fqdn, ip_address, os_details, polling_token, created_at, expires_at
"#,
)
.bind(req.machine_id)
.bind(req.fqdn)
.bind(req.ip_address)
.bind(req.os_details)
.bind(token_hash)
.fetch_one(pool)
.await
}
pub async fn list_enrollment_requests(
pool: &PgPool,
) -> Result<Vec<EnrollmentRequest>, sqlx::Error> {
sqlx::query_as::<_, EnrollmentRequest>(
"SELECT id, machine_id, fqdn, ip_address, os_details, polling_token, created_at, expires_at FROM enrollment_requests ORDER BY created_at DESC",
)
.fetch_all(pool)
.await
}
pub async fn delete_enrollment_request(pool: &PgPool, id: Uuid) -> Result<u64, sqlx::Error> {
let result = sqlx::query("DELETE FROM enrollment_requests WHERE id = $1")
.bind(id)
.execute(pool)
.await?;
Ok(result.rows_affected())
}
/// Check that the database schema is at the expected version.
/// Used by the worker to wait until migrations have been applied.
pub async fn check_schema_version(pool: &PgPool) -> Result<i64, sqlx::Error> {

View File

@ -123,6 +123,51 @@ pub struct HostSummary {
pub registered_at: DateTime<Utc>,
}
// ============================================================
// Host Enrollment
// ============================================================
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct EnrollmentRequest {
pub id: Uuid,
pub machine_id: String,
pub fqdn: String,
pub ip_address: String,
pub os_details: serde_json::Value,
pub polling_token: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
/// Payload for initial host enrollment request.
#[derive(Debug, Deserialize, Serialize)]
pub struct CreateEnrollmentRequest {
pub machine_id: String,
pub fqdn: String,
pub ip_address: String,
pub os_details: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "lowercase")]
pub enum EnrollmentStatusResponse {
Pending,
Approved {
ca_crt: String,
server_crt: String,
server_key: String,
},
Denied,
NotFound,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PkiBundle {
pub ca_crt: String,
pub server_crt: String,
pub server_key: String,
}
// ============================================================
// Health Checks
// ============================================================

View File

@ -36,6 +36,7 @@ dashmap = { version = "6" }
reqwest = { workspace = true }
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
rand = { workspace = true }
hex = "0.4"
base64 = { workspace = true }
sha2 = { workspace = true }
jsonwebtoken = { workspace = true }

View File

@ -9,11 +9,17 @@ use pm_auth::{
jwt,
rbac::{require_auth, AuthConfig},
};
use pm_core::{config::AppConfig, db, logging, request_id::request_id_middleware};
use pm_core::{
config::AppConfig, db, logging, models::PkiBundle, request_id::request_id_middleware,
};
use routes::sso::{OidcCache, SsoSession};
use routes::ws::WsTicket;
use serde_json::{json, Value};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use tower_http::{
services::{ServeDir, ServeFile},
@ -35,6 +41,10 @@ pub struct AppState {
pub oidc_cache: Arc<Mutex<OidcCache>>,
/// Internal certificate authority for mTLS client cert issuance.
pub ca: Arc<pm_ca::CertAuthority>,
/// IP-based rate limits for enrollment requests.
pub enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>>,
/// Short-lived cache for approved enrollment PKI bundles.
pub approved_enrollments: Arc<DashMap<String, PkiBundle>>,
}
#[tokio::main]
@ -91,6 +101,8 @@ async fn main() -> anyhow::Result<()> {
let ws_tickets: Arc<DashMap<String, WsTicket>> = Arc::new(DashMap::new());
let sso_sessions: Arc<DashMap<String, SsoSession>> = Arc::new(DashMap::new());
let oidc_cache: Arc<Mutex<OidcCache>> = Arc::new(Mutex::new(OidcCache::default()));
let enrollment_rate_limits: Arc<DashMap<IpAddr, Instant>> = Arc::new(DashMap::new());
let approved_enrollments: Arc<DashMap<String, PkiBundle>> = Arc::new(DashMap::new());
// Background task: purge expired WS tickets every 30 seconds.
{
@ -129,6 +141,31 @@ async fn main() -> anyhow::Result<()> {
});
}
// Background task: purge expired enrollment rate limits every 5 minutes.
{
let limits = enrollment_rate_limits.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300));
loop {
interval.tick().await;
let now = Instant::now();
limits.retain(|_, v| now.duration_since(*v) < Duration::from_secs(3600));
}
});
}
// Background task: purge approved enrollment PKI bundles every 10 minutes.
{
let approved = approved_enrollments.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(600));
loop {
interval.tick().await;
approved.clear();
}
});
}
let state = AppState {
db: pool,
config: Arc::new(config.clone()),
@ -137,6 +174,8 @@ async fn main() -> anyhow::Result<()> {
ws_tickets,
sso_sessions,
ca: Arc::new(ca),
enrollment_rate_limits,
approved_enrollments,
oidc_cache,
};
@ -223,6 +262,8 @@ pub fn build_router(state: AppState) -> Router {
)
// Settings (admin-only)
.nest("/settings", routes::settings::router())
// Admin enrollment routes (JWT protected, Admin role enforced)
.nest("/admin", routes::enrollment::admin_router())
// Apply auth middleware to all the above
.route_layer(middleware::from_fn(move |req, next| {
let auth_config = auth_config.clone();
@ -233,6 +274,8 @@ pub fn build_router(state: AppState) -> Router {
.route("/status/health", get(health_handler))
// Public auth routes (no JWT needed)
.nest("/api/v1/auth", routes::auth::public_router())
// Public enrollment endpoints (rate-limited, no JWT)
.nest("/api/v1", routes::enrollment::router())
// Public SSO routes (no JWT needed)
.nest("/api/v1/auth/sso", routes::sso::public_router())
// Public Azure SSO routes (no JWT needed)

View File

@ -0,0 +1,318 @@
use crate::AppState;
use axum::{
extract::{ConnectInfo, Path, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::{delete, get, post},
Json, Router,
};
use chrono::Utc;
use pm_auth::AuthUser;
use pm_core::{
db,
models::{
CreateEnrollmentRequest, EnrollmentRequest, EnrollmentStatusResponse, Host, PkiBundle,
},
};
use rand::{distributions::Alphanumeric, Rng};
use serde::Serialize;
use std::net::{IpAddr, SocketAddr};
use std::time::Instant;
#[derive(Debug, Clone, Serialize)]
pub struct HostConflict {
pub existing_host: Host,
pub message: String,
}
/// Define public enrollment routes.
pub fn router() -> Router<AppState> {
Router::new()
.route("/enroll", post(enroll_host))
.route("/enroll/status/:token", get(enroll_status))
}
/// POST /api/v1/enroll
/// Initiates host self-enrollment.
async fn enroll_host(
State(state): State<AppState>,
headers: HeaderMap,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(payload): Json<CreateEnrollmentRequest>,
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
// 1. IP-based Rate Limiting
// Prefer real IP from headers if behind proxy (e.g., X-Forwarded-For), else use SocketAddr
let ip = headers
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(',').next())
.and_then(|h| h.trim().parse::<IpAddr>().ok())
.unwrap_or_else(|| addr.ip());
{
let mut rate_limits = state
.enrollment_rate_limits
.entry(ip)
.or_insert(Instant::now() - std::time::Duration::from_secs(3600));
let last_request = rate_limits.value();
if last_request.elapsed().as_secs() < 60 {
// 1 request per minute per IP
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({ "error": "Rate limit exceeded. Try again in a minute." })),
));
}
*rate_limits = Instant::now();
}
// 2. Generate secure random polling token
let polling_token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
// For database storage, we'll hash the token (spec says hashed)
// Using a simple SHA256 or similar for the hash storage
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(polling_token.as_bytes());
let token_hash = hex::encode(hasher.finalize());
// 3. Store in DB
db::create_enrollment_request(&state.db, payload, token_hash)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to create enrollment request");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// 4. Return the raw token to the client
Ok((
StatusCode::ACCEPTED,
Json(serde_json::json!({ "polling_token": polling_token })),
)
.into_response())
}
/// GET /api/v1/enroll/status/:token
/// Returns status of enrollment (pending/approved/denied/not_found).
async fn enroll_status(
State(state): State<AppState>,
Path(token): Path<String>,
) -> Result<Json<EnrollmentStatusResponse>, (StatusCode, Json<serde_json::Value>)> {
// Hash the provided token to match DB
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(token.as_bytes());
let token_hash = hex::encode(hasher.finalize());
// 1. Check enrollment_requests table
let requests = db::list_enrollment_requests(&state.db).await.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
if let Some(req) = requests.into_iter().find(|r| r.polling_token == token_hash) {
if req.expires_at < Utc::now() {
return Ok(Json(EnrollmentStatusResponse::NotFound));
}
return Ok(Json(EnrollmentStatusResponse::Pending));
}
// 2. If not in pending, check if it was recently approved.
if let Some(pki) = state.approved_enrollments.get(&token_hash) {
return Ok(Json(EnrollmentStatusResponse::Approved {
ca_crt: pki.ca_crt.clone(),
server_crt: pki.server_crt.clone(),
server_key: pki.server_key.clone(),
}));
}
Ok(Json(EnrollmentStatusResponse::NotFound))
}
/// Define admin enrollment routes.
pub fn admin_router() -> Router<AppState> {
Router::new()
.route("/enrollments", get(list_admin_enrollments))
.route("/enrollments/:id/approve", post(approve_enrollment))
.route("/enrollments/:id/deny", delete(deny_enrollment))
}
/// GET /api/v1/admin/enrollments
/// Lists all pending enrollment requests.
async fn list_admin_enrollments(
State(state): State<AppState>,
auth: AuthUser,
) -> Result<Json<Vec<EnrollmentRequest>>, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
db::list_enrollment_requests(&state.db)
.await
.map(|requests| Json(requests))
.map_err(|e| {
tracing::error!(error = %e, "Failed to list enrollment requests");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})
}
/// POST /api/v1/admin/enrollments/{id}/approve
/// Approves a pending enrollment request, generates PKI, and moves to hosts table.
async fn approve_enrollment(
State(state): State<AppState>,
Path(id): Path<uuid::Uuid>,
auth: AuthUser,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
// Fetch the enrollment request
let mut requests = db::list_enrollment_requests(&state.db).await.map_err(|e| {
tracing::error!(error = %e, "Failed to list enrollment requests for approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
let enrollment_request = match requests.iter().position(|r| r.id == id) {
Some(idx) => requests.remove(idx),
None => return Ok(StatusCode::NOT_FOUND),
};
// Check for FQDN/IP collision in hosts table
if let Some(existing_host) = sqlx::query_as::<_, Host>(
"SELECT id, fqdn, ip_address, display_name, os_family, os_name, arch, agent_version, health_status, last_health_at, last_patch_at, agent_port, notes, registered_at, updated_at FROM hosts WHERE fqdn = $1 OR ip_address = $2"
)
.bind(&enrollment_request.fqdn)
.bind(&enrollment_request.ip_address.to_string())
.fetch_optional(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to check for host collision");
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": "Database error" })))
})? {
return Err((
StatusCode::CONFLICT,
Json(serde_json::json!({ "error": "Host collision detected", "conflict": HostConflict { existing_host, message: "FQDN or IP already exists".to_string() } }))
));
}
// Generate PKI bundle using CA
let issued = state
.ca
.issue_client_cert(
enrollment_request.id,
&enrollment_request.fqdn,
&enrollment_request.ip_address.to_string(),
&state.db,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to issue client certificate");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Certificate generation failed" })),
)
})?;
// Move to hosts table
let os_name = enrollment_request
.os_details
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
sqlx::query(
r#"
INSERT INTO hosts (id, fqdn, ip_address, os_name, registered_at, updated_at, machine_id)
VALUES ($1, $2, $3, $4, NOW(), NOW(), $5)
"#,
)
.bind(enrollment_request.id)
.bind(&enrollment_request.fqdn)
.bind(&enrollment_request.ip_address.to_string())
.bind(os_name)
.bind(enrollment_request.machine_id)
.execute(&state.db)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to insert host after approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// Delete from enrollment_requests table
db::delete_enrollment_request(&state.db, id)
.await
.map_err(|e| {
tracing::error!(error = %e, "Failed to delete enrollment request after approval");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})?;
// Store PKI bundle in cache for client retrieval
let pki = PkiBundle {
ca_crt: issued.ca_root_pem,
server_crt: issued.server_cert_pem,
server_key: issued.server_key_pem,
};
state
.approved_enrollments
.insert(enrollment_request.polling_token.clone(), pki);
Ok(StatusCode::OK)
}
/// DELETE /api/v1/admin/enrollments/{id}/deny
/// Denies and purges a pending enrollment request.
async fn deny_enrollment(
State(state): State<AppState>,
Path(id): Path<uuid::Uuid>,
auth: AuthUser,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
if !auth.role.is_admin() {
return Err((
StatusCode::FORBIDDEN,
Json(
serde_json::json!({ "error": { "code": "forbidden", "message": "Admin role required" } }),
),
));
}
db::delete_enrollment_request(&state.db, id)
.await
.map(|_| StatusCode::NO_CONTENT)
.map_err(|e| {
tracing::error!(error = %e, "Failed to deny enrollment request");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": "Database error" })),
)
})
}

View File

@ -2,6 +2,7 @@
pub mod auth;
pub mod ca;
pub mod discovery;
pub mod enrollment;
pub mod groups;
pub mod health_checks;
pub mod hosts;

View File

@ -14,6 +14,7 @@ mod patch_poller;
mod refresh_listener;
mod ws_relay;
use chrono::Utc;
use pm_core::{config::AppConfig, db, logging};
use sqlx::PgPool;
use std::{sync::Arc, time::Duration};
@ -31,7 +32,7 @@ use ws_relay::run_ws_relay;
/// Minimum number of applied migrations the worker requires before
/// accepting work. Prevents the worker from running against a schema
/// that hasn't been migrated yet.
const REQUIRED_MIGRATION_COUNT: i64 = 8;
const REQUIRED_MIGRATION_COUNT: i64 = 16;
/// How long to wait between schema-version checks before giving up.
const SCHEMA_CHECK_TIMEOUT: Duration = Duration::from_secs(120);
@ -94,6 +95,9 @@ async fn main() -> anyhow::Result<()> {
// Health check poller — runs configured service/HTTP health checks
let health_check_handle = tokio::spawn(run_health_check_poller(pool.clone(), config.clone()));
// Enrollment cleanup task (runs every hour)
let enrollment_cleanup_handle = tokio::spawn(run_enrollment_cleanup_task(pool.clone()));
tracing::info!("Worker tasks started");
// Wait for all tasks (they run indefinitely)
@ -106,6 +110,8 @@ async fn main() -> anyhow::Result<()> {
maint_sched_handle,
ws_relay_handle,
audit_verifier_handle,
health_check_handle,
enrollment_cleanup_handle,
);
Ok(())
@ -174,3 +180,29 @@ async fn run_heartbeat(pool: PgPool, interval_secs: u64) {
}
}
}
/// Periodically deletes expired enrollment requests.
async fn run_enrollment_cleanup_task(pool: PgPool) {
let mut interval = tokio::time::interval(Duration::from_secs(3600)); // Every hour
interval.tick().await; // Initial tick to run immediately if needed
loop {
interval.tick().await;
let now = Utc::now();
match sqlx::query("DELETE FROM enrollment_requests WHERE expires_at < $1")
.bind(now)
.execute(&pool)
.await
{
Ok(result) => {
if result.rows_affected() > 0 {
tracing::info!(
removed = result.rows_affected(),
"Purged expired enrollment requests"
);
}
},
Err(e) => tracing::error!(error = %e, "Failed to purge expired enrollment requests"),
}
}
}