From 44145b39ce07ed0463ca7b561c09effd3599bef2 Mon Sep 17 00:00:00 2001 From: jae beller Date: Sun, 3 Mar 2024 21:25:33 -0500 Subject: [PATCH] clean up graceful shutdown --- Cargo.lock | 7 +++ Cargo.toml | 1 + src/lib.rs | 34 +++++++++++ src/main.rs | 148 +++++++++++++++++++++--------------------------- src/shutdown.rs | 53 +++++++++++++++++ 5 files changed, 160 insertions(+), 83 deletions(-) create mode 100644 src/lib.rs create mode 100644 src/shutdown.rs diff --git a/Cargo.lock b/Cargo.lock index f3d171d..939eb9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" + [[package]] name = "async-trait" version = "0.1.77" @@ -1385,6 +1391,7 @@ checksum = "626dec3cac7cc0e1577a2ec3fc496277ec2baa084bebad95bb6fdbfae235f84c" name = "poser" version = "0.1.0" dependencies = [ + "anyhow", "axum", "base64 0.22.0", "base64ct", diff --git a/Cargo.toml b/Cargo.toml index 8c54f8a..75a4bd3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +anyhow = "1.0" axum = { version = "0.7", features = ["macros"] } base64ct = "1.6" openidconnect = "3.5" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..17555c0 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,34 @@ +//! # poser +//! +//! poser is a simple, opinionated authentication provider for nginx +//! +//! ## About +//! +//! poser authenticates with Google using OpenID Connect and then uses the +//! Google Workspace Admin SDK to determine what groups a user is a part of. +//! Basic information about the user and what groups they are a part of is +//! returned to nginx in a [Paseto v4] token, which is then passed to the +//! application. +//! +//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec + +pub mod config; +pub mod error; +pub mod oidc; +mod routes; +pub mod shutdown; +pub mod token; + +pub use routes::routes; + +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ServerState { + pub config: crate::config::Config, + pub db: Arc, + pub oidc: openidconnect::core::CoreClient, + + // Signals back to the main thread when dropped + pub shutdown: crate::shutdown::Receiver, +} diff --git a/src/main.rs b/src/main.rs index d453ba2..32470b7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,69 +1,39 @@ -//! # poser -//! -//! poser is a simple, opinionated authentication provider for nginx -//! -//! ## About -//! -//! poser authenticates with Google using OpenID Connect and then uses the -//! Google Workspace Admin SDK to determine what groups a user is a part of. -//! Basic information about the user and what groups they are a part of is -//! returned to nginx in a [Paseto v4] token, which is then passed to the -//! application. -//! -//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec - -pub mod config; -pub mod error; -pub mod oidc; -pub mod routes; -pub mod token; - use std::env::var; use std::future::IntoFuture; use std::sync::Arc; -use config::Config; -use oidc::setup_auth; -use openidconnect::core::CoreClient; -use routes::routes; +use poser::config::Config; +use poser::oidc::setup_auth; +use poser::shutdown; +use poser::{routes, ServerState}; +use anyhow::Context; use tokio::{ net::TcpListener, runtime::Runtime, - select, signal::unix::{signal, SignalKind}, - sync::{broadcast, mpsc}, time::timeout, }; -use tokio_postgres::{Client, NoTls}; +use tokio_postgres::NoTls; use tower::ServiceBuilder; use tower_cookies::CookieManagerLayer; -use tower_http::{ - trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}, - LatencyUnit, -}; +use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; use tracing::{debug, error, info, warn, Level}; -#[derive(Debug, Clone)] -pub struct ServerState { - pub config: Config, - pub db: Arc, - pub oidc: CoreClient, - - // Signals back to the main thread when dropped - _shutdown_complete: mpsc::Sender<()>, -} - fn main() { tracing_subscriber::fmt() - .with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "info".to_string())) + .with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "warn,poser=info".to_string())) .init(); - let config = Config::try_env().expect("invalid configuration"); + let config = Config::try_env() + .context("failed to build config") + .unwrap_or_else(|e| { + error!("{:#}", e); + std::process::exit(1); + }); build_runtime().block_on(async move { - let (shutdown_notify, _) = broadcast::channel(1); - let (shutdown_tx, shutdown_rx) = mpsc::channel(1); + let shutdown = shutdown::Sender::new(); let (db, conn) = tokio_postgres::connect(&config.database, NoTls) .await @@ -77,15 +47,15 @@ fn main() { config: config.clone(), db: Arc::new(db), oidc, - _shutdown_complete: shutdown_tx.clone(), + shutdown: shutdown.subscribe(), }; - let db_signal = shutdown_tx.clone(); - let conn = tokio::spawn(async move { + let postgres_notify = shutdown.subscribe(); + let postgres = async move { let res = conn.await; - drop(db_signal); + drop(postgres_notify); res - }); + }; let router = routes().with_state(state).layer( ServiceBuilder::new() @@ -95,28 +65,28 @@ fn main() { .on_response( DefaultOnResponse::new() .level(Level::INFO) - .latency_unit(LatencyUnit::Micros), + .latency_unit(tower_http::LatencyUnit::Micros), ), ) .layer(CookieManagerLayer::new()), ); + let listener = TcpListener::bind(&config.addr) .await - .expect("failed to bind to socket"); - let shutdown_signal = shutdown_notify.subscribe(); - let server = axum::serve(listener, router) - .with_graceful_shutdown(wait_for_shutdown(shutdown_signal)); + .context("failed to bind to socket") + .unwrap_or_else(|e| { + error!("{:#}", e); + std::process::exit(1); + }); - info!("listening on {}", config.addr); + let mut axum_notify = shutdown.subscribe(); + let server = axum::serve(listener, router) + .with_graceful_shutdown(async move { _ = axum_notify.recv().await }); - select! { - _ = unix_signal(SignalKind::interrupt()) => { - info!("received SIGINT, shutting down"); - }, - _ = unix_signal(SignalKind::terminate()) => { - info!("received SIGTERM, shutting down"); - }, - res = conn => match res { + info!("listening for connections on: {}", config.addr); + tokio::select! { + sig = shutdown_signal() => info!("received {}, starting graceful shutdown...", sig), + res = tokio::spawn(postgres) => match res { Ok(Ok(_)) => error!("database connection closed unexpectedly"), Ok(Err(e)) => error!("database connection error: {}", e), Err(e) => error!("database executor unexpectedly stopped: {}", e), @@ -127,14 +97,16 @@ fn main() { Err(e) => error!("server executor unexpectedly stopped: {}", e), }, } - drop(shutdown_notify); - drop(shutdown_tx); - match timeout(config.grace_period, wait_for_complete(shutdown_rx)).await { - Ok(()) => debug!("shutdown completed"), - Err(_) => warn!( - "graceful shutdown did not complete in {:?}, closing anyways", - config.grace_period - ), + + tokio::select! { + sig = shutdown_signal() => error!("received second {}, aborting.", sig), + res = timeout(config.grace_period, shutdown::Sender::shutdown(shutdown)) => match res { + Ok(()) => debug!("shutdown completed"), + Err(_) => warn!( + "graceful shutdown did not complete in {:?}, closing anyways", + config.grace_period + ), + }, } }) } @@ -143,17 +115,27 @@ fn build_runtime() -> Runtime { tokio::runtime::Builder::new_multi_thread() .enable_all() .build() - .expect("build threaded runtime") -} - -async fn unix_signal(kind: SignalKind) { - signal(kind).expect("register signal handler").recv().await; -} - -async fn wait_for_shutdown(mut signal: broadcast::Receiver<()>) { - _ = signal.recv().await; + .context("failed to build threaded runtime") + .unwrap_or_else(|e| { + error!("{:#}", e); + std::process::exit(1); + }) } -async fn wait_for_complete(mut signal: mpsc::Receiver<()>) { - _ = signal.recv().await; +async fn shutdown_signal() -> &'static str { + async fn wait_for_signal(kind: SignalKind) { + signal(kind) + .context("failed to register signal handler") + .unwrap_or_else(|e| { + error!("{:#}", e); + std::process::exit(1); + }) + .recv() + .await; + } + + tokio::select! { + _ = wait_for_signal(SignalKind::interrupt()) => "SIGINT", + _ = wait_for_signal(SignalKind::terminate()) => "SIGTERM", + } } diff --git a/src/shutdown.rs b/src/shutdown.rs new file mode 100644 index 0000000..11d617b --- /dev/null +++ b/src/shutdown.rs @@ -0,0 +1,53 @@ +use tokio::sync::{mpsc, watch}; + +#[derive(Debug)] +pub struct Sender { + notify: watch::Sender<()>, + process_tx: mpsc::Sender<()>, + process_rx: mpsc::Receiver<()>, +} + +impl Sender { + pub fn new() -> Self { + let (notify, _) = watch::channel(()); + let (process_tx, process_rx) = mpsc::channel(1); + + Self { + notify, + process_tx, + process_rx, + } + } + + pub fn subscribe(&self) -> Receiver { + Receiver { + notify: self.notify.subscribe(), + _handle: self.process_tx.clone(), + } + } + + pub async fn shutdown(mut self) { + let _ = self.notify.send(()); + + drop(self.process_tx); + let _ = self.process_rx.recv().await; + } +} + +impl Default for Sender { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone, Debug)] +pub struct Receiver { + notify: watch::Receiver<()>, + _handle: mpsc::Sender<()>, +} + +impl Receiver { + pub async fn recv(&mut self) { + let _ = self.notify.changed().await; + } +}