Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update dependencies #5

Merged
merged 3 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,521 changes: 1,002 additions & 519 deletions Cargo.lock

Large diffs are not rendered by default.

25 changes: 13 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,31 @@ version = "0.1.0"
edition = "2021"

[dependencies]
axum = { version = "0.6", features = ["macros"] }
anyhow = "1.0"
axum = { version = "0.7", features = ["macros"] }
base64ct = "1.6"
openidconnect = "2.5"
regex = "1.7"
openidconnect = "3.5"
regex = "1.10"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
time = { version = "0.3", features = ["formatting", "parsing"] }
tokio = { version = "1.27", features = ["full"] }
tokio = { version = "1.36", features = ["full"] }
tokio-postgres = { version = "0.7", features = ["with-time-0_3", "with-uuid-1"] }
tower = "0.4"
tower-cookies = "0.9"
tower-http = { version = "0.4", features = ["trace"] }
tower-cookies = "0.10"
tower-http = { version = "0.5", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.3", features = ["v4", "fast-rng"] }
uuid = { version = "1.7", features = ["v4", "fast-rng"] }

base64 = "0.21"
base64 = "0.22"
blake2 = "0.10"
chacha20 = "0.9"
ed25519-dalek = { version = "2.0.0-rc.2", features = ["pem"] }
ed25519-dalek = { version = "2.1.1", features = ["pem"] }
getrandom = "0.2"
subtle = "2.4"
zeroize = "1.6"
subtle = "2.5"
zeroize = "1.7"

[dev-dependencies]
hex-literal = "0.3"
hex-literal = "0.4"
34 changes: 34 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//! # poser
//!
//! poser is a simple, opinionated authentication provider for nginx
//!
//! ## About
//!
//! poser authenticates with Google using OpenID Connect and then uses the
//! Google Workspace Admin SDK to determine what groups a user is a part of.
//! Basic information about the user and what groups they are a part of is
//! returned to nginx in a [Paseto v4] token, which is then passed to the
//! application.
//!
//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec

pub mod config;
pub mod error;
pub mod oidc;
mod routes;
pub mod shutdown;
pub mod token;

pub use routes::routes;

use std::sync::Arc;

#[derive(Debug, Clone)]
pub struct ServerState {
pub config: crate::config::Config,
pub db: Arc<tokio_postgres::Client>,
pub oidc: openidconnect::core::CoreClient,

// Signals back to the main thread when dropped
pub shutdown: crate::shutdown::Receiver,
}
159 changes: 72 additions & 87 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,68 +1,39 @@
//! # poser
//!
//! poser is a simple, opinionated authentication provider for nginx
//!
//! ## About
//!
//! poser authenticates with Google using OpenID Connect and then uses the
//! Google Workspace Admin SDK to determine what groups a user is a part of.
//! Basic information about the user and what groups they are a part of is
//! returned to nginx in a [Paseto v4] token, which is then passed to the
//! application.
//!
//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec

pub mod config;
pub mod error;
pub mod oidc;
pub mod routes;
pub mod token;

use std::env::var;
use std::future::IntoFuture;
use std::sync::Arc;

use config::Config;
use oidc::setup_auth;
use openidconnect::core::CoreClient;
use routes::routes;
use poser::config::Config;
use poser::oidc::setup_auth;
use poser::shutdown;
use poser::{routes, ServerState};

use axum::Server;
use anyhow::Context;
use tokio::{
net::TcpListener,
runtime::Runtime,
select,
signal::unix::{signal, SignalKind},
sync::{broadcast, mpsc},
time::timeout,
};
use tokio_postgres::{Client, NoTls};
use tokio_postgres::NoTls;
use tower::ServiceBuilder;
use tower_cookies::CookieManagerLayer;
use tower_http::{
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
LatencyUnit,
};
use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer};
use tracing::{debug, error, info, warn, Level};

#[derive(Debug, Clone)]
pub struct ServerState {
pub config: Config,
pub db: Arc<Client>,
pub oidc: CoreClient,

// Signals back to the main thread when dropped
_shutdown_complete: mpsc::Sender<()>,
}

fn main() {
tracing_subscriber::fmt()
.with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "info".to_string()))
.with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "warn,poser=info".to_string()))
.init();

let config = Config::try_env().expect("invalid configuration");
let config = Config::try_env()
.context("failed to build config")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
});

build_runtime().block_on(async move {
let (shutdown_notify, _) = broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let shutdown = shutdown::Sender::new();

let (db, conn) = tokio_postgres::connect(&config.database, NoTls)
.await
Expand All @@ -76,62 +47,66 @@ fn main() {
config: config.clone(),
db: Arc::new(db),
oidc,
_shutdown_complete: shutdown_tx.clone(),
shutdown: shutdown.subscribe(),
};

let db_signal = shutdown_tx.clone();
let conn = tokio::spawn(async move {
let postgres_notify = shutdown.subscribe();
let postgres = async move {
let res = conn.await;
drop(db_signal);
drop(postgres_notify);
res
});
};

let app = routes().with_state(state).layer(
let router = routes().with_state(state).layer(
ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.on_request(DefaultOnRequest::new().level(Level::INFO))
.on_response(
DefaultOnResponse::new()
.level(Level::INFO)
.latency_unit(LatencyUnit::Micros),
.latency_unit(tower_http::LatencyUnit::Micros),
),
)
.layer(CookieManagerLayer::new()),
);
let shutdown_signal = shutdown_notify.subscribe();
let server = Server::bind(&config.addr)
.serve(app.into_make_service())
.with_graceful_shutdown(wait_for_shutdown(shutdown_signal));

info!("listening on {}", config.addr);

select! {
_ = unix_signal(SignalKind::interrupt()) => {
info!("received SIGINT, shutting down");
},
_ = unix_signal(SignalKind::terminate()) => {
info!("received SIGTERM, shutting down");
},
res = conn => match res {
let listener = TcpListener::bind(&config.addr)
.await
.context("failed to bind to socket")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
});

let mut axum_notify = shutdown.subscribe();
let server = axum::serve(listener, router)
.with_graceful_shutdown(async move { _ = axum_notify.recv().await });

info!("listening for connections on: {}", config.addr);
tokio::select! {
sig = shutdown_signal() => info!("received {}, starting graceful shutdown...", sig),
res = tokio::spawn(postgres) => match res {
Ok(Ok(_)) => error!("database connection closed unexpectedly"),
Ok(Err(e)) => error!("database connection error: {}", e),
Err(e) => error!("database executor unexpectedly stopped: {}", e),
},
res = tokio::spawn(server) => match res {
res = tokio::spawn(server.into_future()) => match res {
Ok(Ok(_)) => info!("server shutting down"),
Ok(Err(e)) => error!("server unexpectedly stopped: {}", e),
Err(e) => error!("server executor unexpectedly stopped: {}", e),
},
}
drop(shutdown_notify);
drop(shutdown_tx);
match timeout(config.grace_period, wait_for_complete(shutdown_rx)).await {
Ok(()) => debug!("shutdown completed"),
Err(_) => warn!(
"graceful shutdown did not complete in {:?}, closing anyways",
config.grace_period
),

tokio::select! {
sig = shutdown_signal() => error!("received second {}, aborting.", sig),
res = timeout(config.grace_period, shutdown::Sender::shutdown(shutdown)) => match res {
Ok(()) => debug!("shutdown completed"),
Err(_) => warn!(
"graceful shutdown did not complete in {:?}, closing anyways",
config.grace_period
),
},
}
})
}
Expand All @@ -140,17 +115,27 @@ fn build_runtime() -> Runtime {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("build threaded runtime")
}

async fn unix_signal(kind: SignalKind) {
signal(kind).expect("register signal handler").recv().await;
}

async fn wait_for_shutdown(mut signal: broadcast::Receiver<()>) {
_ = signal.recv().await;
.context("failed to build threaded runtime")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
})
}

async fn wait_for_complete(mut signal: mpsc::Receiver<()>) {
_ = signal.recv().await;
async fn shutdown_signal() -> &'static str {
async fn wait_for_signal(kind: SignalKind) {
signal(kind)
.context("failed to register signal handler")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
})
.recv()
.await;
}

tokio::select! {
_ = wait_for_signal(SignalKind::interrupt()) => "SIGINT",
_ = wait_for_signal(SignalKind::terminate()) => "SIGTERM",
}
}
29 changes: 16 additions & 13 deletions src/routes/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ pub enum Error {
#[error("missing id token")]
MissingToken,
#[error("error exchanging code")]
CodeError,
CodeExchange,
#[error("error verifying token")]
TokenError,
VerifyToken,
#[error("error interacting with database")]
DatabaseError,
DatabaseInsert,
}

/// A handler for receiving the callback during the OIDC flow.
Expand Down Expand Up @@ -86,11 +86,11 @@ pub async fn callback_handler(
let id = create_session(&state.db, &token, &expiration).await?;

cookies.add(
Cookie::build(state.config.cookie.name, id.simple().to_string())
Cookie::build((state.config.cookie.name, id.simple().to_string()))
.secure(state.config.cookie.secure)
.http_only(true)
.expires(expiration)
.finish(),
.build(),
);

Ok(Redirect::to(oidc.get_redirect()))
Expand All @@ -107,7 +107,7 @@ async fn get_token(
.await
.map_err(|e| {
error!("failed to exchange code for token: {}", e);
Error::CodeError
Error::CodeExchange
})?;

let token = token_response.extra_fields().id_token().ok_or_else(|| {
Expand All @@ -118,19 +118,22 @@ async fn get_token(
let id_token_verifier = client.id_token_verifier();
let claims = token.claims(&id_token_verifier, nonce).map_err(|e| {
error!("failed to verify id token: {}", e);
Error::TokenError
Error::VerifyToken
})?;

let subj = claims.subject();
let name = claims.name().and_then(|s| s.get(None)).ok_or_else(|| {
error!("name missing from id token");
Error::TokenError
Error::VerifyToken
})?;
let email = claims.email().ok_or_else(|| {
error!("email missing from id token");
Error::TokenError
Error::VerifyToken
})?;
let expiration = claims.expiration().timestamp_nanos();
let expiration = claims
.expiration()
.timestamp_nanos_opt()
.expect("todo: fix timestamp handling before 2262");

Ok((
UserToken {
Expand Down Expand Up @@ -174,7 +177,7 @@ async fn create_session(
.await
.map_err(|e| {
error!("error creating session: {}", e);
Error::DatabaseError
Error::DatabaseInsert
})
.map(|r| r.get::<_, Uuid>("id"))
}
Expand All @@ -186,8 +189,8 @@ impl IntoResponse for Error {
| Error::MissingState
| Error::MissingCode
| Error::MissingCookie => json!({ "error": "invalid request" }),
Error::MissingToken | Error::TokenError => json!({ "error": "authentication error" }),
Error::InvalidDateTime | Error::CodeError | Error::DatabaseError => {
Error::MissingToken | Error::VerifyToken => json!({ "error": "authentication error" }),
Error::InvalidDateTime | Error::CodeExchange | Error::DatabaseInsert => {
json!({ "error": "internal error" })
}
};
Expand Down
Loading
Loading