Skip to content

Commit

Permalink
clean up graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellerb committed Mar 4, 2024
1 parent 6f97dfa commit 44145b3
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 83 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1.0"
axum = { version = "0.7", features = ["macros"] }
base64ct = "1.6"
openidconnect = "3.5"
Expand Down
34 changes: 34 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//! # poser
//!
//! poser is a simple, opinionated authentication provider for nginx
//!
//! ## About
//!
//! poser authenticates with Google using OpenID Connect and then uses the
//! Google Workspace Admin SDK to determine what groups a user is a part of.
//! Basic information about the user and what groups they are a part of is
//! returned to nginx in a [Paseto v4] token, which is then passed to the
//! application.
//!
//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec
pub mod config;
pub mod error;
pub mod oidc;
mod routes;
pub mod shutdown;
pub mod token;

pub use routes::routes;

use std::sync::Arc;

#[derive(Debug, Clone)]
pub struct ServerState {
pub config: crate::config::Config,
pub db: Arc<tokio_postgres::Client>,
pub oidc: openidconnect::core::CoreClient,

// Signals back to the main thread when dropped
pub shutdown: crate::shutdown::Receiver,
}
148 changes: 65 additions & 83 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,69 +1,39 @@
//! # poser
//!
//! poser is a simple, opinionated authentication provider for nginx
//!
//! ## About
//!
//! poser authenticates with Google using OpenID Connect and then uses the
//! Google Workspace Admin SDK to determine what groups a user is a part of.
//! Basic information about the user and what groups they are a part of is
//! returned to nginx in a [Paseto v4] token, which is then passed to the
//! application.
//!
//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec
pub mod config;
pub mod error;
pub mod oidc;
pub mod routes;
pub mod token;

use std::env::var;
use std::future::IntoFuture;
use std::sync::Arc;

use config::Config;
use oidc::setup_auth;
use openidconnect::core::CoreClient;
use routes::routes;
use poser::config::Config;
use poser::oidc::setup_auth;
use poser::shutdown;
use poser::{routes, ServerState};

use anyhow::Context;
use tokio::{
net::TcpListener,
runtime::Runtime,
select,
signal::unix::{signal, SignalKind},
sync::{broadcast, mpsc},
time::timeout,
};
use tokio_postgres::{Client, NoTls};
use tokio_postgres::NoTls;
use tower::ServiceBuilder;
use tower_cookies::CookieManagerLayer;
use tower_http::{
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
LatencyUnit,
};
use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer};
use tracing::{debug, error, info, warn, Level};

#[derive(Debug, Clone)]
pub struct ServerState {
pub config: Config,
pub db: Arc<Client>,
pub oidc: CoreClient,

// Signals back to the main thread when dropped
_shutdown_complete: mpsc::Sender<()>,
}

fn main() {
tracing_subscriber::fmt()
.with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "info".to_string()))
.with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "warn,poser=info".to_string()))
.init();

let config = Config::try_env().expect("invalid configuration");
let config = Config::try_env()
.context("failed to build config")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
});

build_runtime().block_on(async move {
let (shutdown_notify, _) = broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let shutdown = shutdown::Sender::new();

let (db, conn) = tokio_postgres::connect(&config.database, NoTls)
.await
Expand All @@ -77,15 +47,15 @@ fn main() {
config: config.clone(),
db: Arc::new(db),
oidc,
_shutdown_complete: shutdown_tx.clone(),
shutdown: shutdown.subscribe(),
};

let db_signal = shutdown_tx.clone();
let conn = tokio::spawn(async move {
let postgres_notify = shutdown.subscribe();
let postgres = async move {
let res = conn.await;
drop(db_signal);
drop(postgres_notify);
res
});
};

let router = routes().with_state(state).layer(
ServiceBuilder::new()
Expand All @@ -95,28 +65,28 @@ fn main() {
.on_response(
DefaultOnResponse::new()
.level(Level::INFO)
.latency_unit(LatencyUnit::Micros),
.latency_unit(tower_http::LatencyUnit::Micros),
),
)
.layer(CookieManagerLayer::new()),
);

let listener = TcpListener::bind(&config.addr)
.await
.expect("failed to bind to socket");
let shutdown_signal = shutdown_notify.subscribe();
let server = axum::serve(listener, router)
.with_graceful_shutdown(wait_for_shutdown(shutdown_signal));
.context("failed to bind to socket")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
});

info!("listening on {}", config.addr);
let mut axum_notify = shutdown.subscribe();
let server = axum::serve(listener, router)
.with_graceful_shutdown(async move { _ = axum_notify.recv().await });

select! {
_ = unix_signal(SignalKind::interrupt()) => {
info!("received SIGINT, shutting down");
},
_ = unix_signal(SignalKind::terminate()) => {
info!("received SIGTERM, shutting down");
},
res = conn => match res {
info!("listening for connections on: {}", config.addr);
tokio::select! {
sig = shutdown_signal() => info!("received {}, starting graceful shutdown...", sig),
res = tokio::spawn(postgres) => match res {
Ok(Ok(_)) => error!("database connection closed unexpectedly"),
Ok(Err(e)) => error!("database connection error: {}", e),
Err(e) => error!("database executor unexpectedly stopped: {}", e),
Expand All @@ -127,14 +97,16 @@ fn main() {
Err(e) => error!("server executor unexpectedly stopped: {}", e),
},
}
drop(shutdown_notify);
drop(shutdown_tx);
match timeout(config.grace_period, wait_for_complete(shutdown_rx)).await {
Ok(()) => debug!("shutdown completed"),
Err(_) => warn!(
"graceful shutdown did not complete in {:?}, closing anyways",
config.grace_period
),

tokio::select! {
sig = shutdown_signal() => error!("received second {}, aborting.", sig),
res = timeout(config.grace_period, shutdown::Sender::shutdown(shutdown)) => match res {
Ok(()) => debug!("shutdown completed"),
Err(_) => warn!(
"graceful shutdown did not complete in {:?}, closing anyways",
config.grace_period
),
},
}
})
}
Expand All @@ -143,17 +115,27 @@ fn build_runtime() -> Runtime {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("build threaded runtime")
}

async fn unix_signal(kind: SignalKind) {
signal(kind).expect("register signal handler").recv().await;
}

async fn wait_for_shutdown(mut signal: broadcast::Receiver<()>) {
_ = signal.recv().await;
.context("failed to build threaded runtime")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
})
}

async fn wait_for_complete(mut signal: mpsc::Receiver<()>) {
_ = signal.recv().await;
async fn shutdown_signal() -> &'static str {
async fn wait_for_signal(kind: SignalKind) {
signal(kind)
.context("failed to register signal handler")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
})
.recv()
.await;
}

tokio::select! {
_ = wait_for_signal(SignalKind::interrupt()) => "SIGINT",
_ = wait_for_signal(SignalKind::terminate()) => "SIGTERM",
}
}
53 changes: 53 additions & 0 deletions src/shutdown.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use tokio::sync::{mpsc, watch};

#[derive(Debug)]
pub struct Sender {
notify: watch::Sender<()>,
process_tx: mpsc::Sender<()>,
process_rx: mpsc::Receiver<()>,
}

impl Sender {
pub fn new() -> Self {
let (notify, _) = watch::channel(());
let (process_tx, process_rx) = mpsc::channel(1);

Self {
notify,
process_tx,
process_rx,
}
}

pub fn subscribe(&self) -> Receiver {
Receiver {
notify: self.notify.subscribe(),
_handle: self.process_tx.clone(),
}
}

pub async fn shutdown(mut self) {
let _ = self.notify.send(());

drop(self.process_tx);
let _ = self.process_rx.recv().await;
}
}

impl Default for Sender {
fn default() -> Self {
Self::new()
}
}

#[derive(Clone, Debug)]
pub struct Receiver {
notify: watch::Receiver<()>,
_handle: mpsc::Sender<()>,
}

impl Receiver {
pub async fn recv(&mut self) {
let _ = self.notify.changed().await;
}
}

0 comments on commit 44145b3

Please sign in to comment.