From 4a528256353da8414566e0117abe52126b18fd5c Mon Sep 17 00:00:00 2001 From: Billy Batista Date: Tue, 17 Sep 2024 01:34:16 -0400 Subject: [PATCH] check and cache revoked tokens we should eventually store this in sqlite or something, but the list is tiny so in mem is fine --- .envrc | 2 +- Cargo.lock | 1 + Cargo.toml | 5 ++-- src/action.rs | 20 ++++---------- src/auth.rs | 8 +++++- src/io.rs | 0 src/main.rs | 8 ++++-- src/meta_db.rs | 2 +- src/tokens.rs | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 97 insertions(+), 23 deletions(-) delete mode 100644 src/io.rs create mode 100644 src/tokens.rs diff --git a/.envrc b/.envrc index 08b4017..ec7df30 100644 --- a/.envrc +++ b/.envrc @@ -1,2 +1,2 @@ -# this is *not* the prod public key, obviously lol +export BIG_CENTRAL_URL="http://localhost:4000"; use flake; diff --git a/Cargo.lock b/Cargo.lock index 4876c26..b777ad7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -882,6 +882,7 @@ dependencies = [ "biscuit-auth", "bytes", "futures", + "hex", "humantime", "instant-acme", "opentelemetry", diff --git a/Cargo.toml b/Cargo.toml index 3aa4b81..cf0dd58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ tracing = { version = "0.1" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } opentelemetry-otlp = { version = "0.15.0", features = ["http-proto", "reqwest", "reqwest-client", "reqwest-rustls", "trace", "tokio"] } opentelemetry = "0.22" -reqwest = "0.11" +reqwest = { version = "0.11", features = ["json"] } opentelemetry_sdk = { version = "0.22", features = ["rt-tokio", "trace"] } tracing-opentelemetry = "0.23" @@ -37,7 +37,8 @@ rcgen = "0.12" warp = "0.3" serde = "1" serde_json = { version = "1", features = ["raw_value"] } -bytes = "1.6" +bytes = "1" +hex = "0.4" [features] s3 = ["rust-s3"] diff --git a/src/action.rs b/src/action.rs index d4dda61..3d9ca70 100644 --- a/src/action.rs +++ b/src/action.rs @@ -1,14 +1,12 @@ use anyhow::anyhow; use bfsp::internal::ActionInfo; -use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; -use tracing::error_span; use tracing::{error, Level}; use rand::Rng; -use crate::{chunk_db::ChunkDB, meta_db::MetaDB}; +use crate::meta_db::MetaDB; #[derive(Debug)] enum Action { @@ -35,10 +33,7 @@ impl TryFrom for Action { } } -pub async fn check_run_actions_loop( - meta_db: Arc, - chunk_db: Arc, -) { +pub async fn check_run_actions_loop(meta_db: Arc) { loop { tracing::span!(Level::INFO, "run_current_actions"); @@ -49,10 +44,9 @@ pub async fn check_run_actions_loop( Ok(actions) => { for action_info in actions.into_iter() { let meta_db = Arc::clone(&meta_db); - let chunk_db = Arc::clone(&chunk_db); tokio::task::spawn(async move { - match run_action(Arc::clone(&meta_db), chunk_db, &action_info).await { + match run_action(Arc::clone(&meta_db), &action_info).await { Ok(_) => { let _ = meta_db.executed_action(action_info.id.unwrap()).await; } @@ -74,12 +68,8 @@ pub async fn check_run_actions_loop( } } -#[tracing::instrument(err, skip(meta_db, chunk_db))] -async fn run_action( - meta_db: Arc, - chunk_db: Arc, - action_info: &ActionInfo, -) -> anyhow::Result<()> { +#[tracing::instrument(err, skip(meta_db))] +async fn run_action(meta_db: Arc, action_info: &ActionInfo) -> anyhow::Result<()> { let action: Action = action_info.action.clone().try_into()?; let user_id = action_info.user_id; diff --git a/src/auth.rs b/src/auth.rs index 4b11b20..9264640 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -10,7 +10,7 @@ use biscuit_auth::{ }; use tracing::{event, Level}; -use crate::meta_db::MetaDB; +use crate::{meta_db::MetaDB, tokens::check_token_revoked}; #[derive(Debug)] pub enum Right { @@ -20,6 +20,7 @@ pub enum Right { Delete, Usage, Payment, + Settings, } impl Right { @@ -31,6 +32,7 @@ impl Right { Right::Delete => "delete", Right::Usage => "usage", Right::Payment => "payment", + Right::Settings => "settings", } } } @@ -42,6 +44,10 @@ pub async fn authorize( file_ids: Vec, meta_db: &M, ) -> anyhow::Result { + if check_token_revoked(token).await { + return Err(anyhow!("token is revoked")); + } + let user_id = get_user_id(token)?; // first, check if the user has been suspended from the right they're trying to execute diff --git a/src/io.rs b/src/io.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/main.rs b/src/main.rs index 73ff91b..d388759 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod chunk_db; mod internal; mod meta_db; mod tls; +mod tokens; use action::check_run_actions_loop; use anyhow::anyhow; @@ -32,6 +33,7 @@ use std::{ collections::{HashMap, HashSet}, sync::Arc, }; +use tokens::refresh_revoked_tokens; use tokio::{fs, io}; use tracing::{event, Level}; use tracing_opentelemetry::PreSampledTracer; @@ -133,12 +135,11 @@ async fn main() -> Result<()> { let chunk_db_clone = Arc::clone(&chunk_db); let meta_db_clone = Arc::clone(&meta_db); + tokio::task::spawn(async move { refresh_revoked_tokens().await }); tokio::task::spawn(async move { chunk_db_clone.garbage_collect(meta_db_clone).await }); - let chunk_db_clone = Arc::clone(&chunk_db); let meta_db_clone = Arc::clone(&meta_db); - - tokio::task::spawn(async move { check_run_actions_loop(meta_db_clone, chunk_db_clone).await }); + tokio::task::spawn(async move { check_run_actions_loop(meta_db_clone).await }); let internal_tcp_addr = "[::]:9990".to_socket_addrs().unwrap().next().unwrap(); @@ -492,6 +493,7 @@ pub async fn handle_message( .encode_to_vec(), Err(_) => todo!(), }, + _ => todo!(), } .prepend_len()) } diff --git a/src/meta_db.rs b/src/meta_db.rs index b571b23..5d8e9b0 100644 --- a/src/meta_db.rs +++ b/src/meta_db.rs @@ -12,7 +12,7 @@ use bfsp::{ use serde::{Deserialize, Serialize}; use sqlx::{ types::{time::OffsetDateTime, Json}, - Execute, Executor, PgPool, QueryBuilder, Row, + Executor, PgPool, QueryBuilder, Row, }; use thiserror::Error; diff --git a/src/tokens.rs b/src/tokens.rs new file mode 100644 index 0000000..c1976b9 --- /dev/null +++ b/src/tokens.rs @@ -0,0 +1,74 @@ +use std::{collections::HashSet, env, sync::OnceLock, time::Duration}; + +use anyhow::anyhow; +use biscuit_auth::Biscuit; +use reqwest::StatusCode; +use tokio::sync::RwLock; + +pub type RevocationIdentiifer = Vec; + +static REVOKED_TOKENS: OnceLock>> = OnceLock::new(); + +#[tracing::instrument] +pub async fn refresh_revoked_tokens() { + let mut token_num: u32 = 0; + loop { + match single_update_revoked_tokens(token_num, 100).await { + Ok(tokens_inserted) => { + // we don't want to get too far behind, we we should keep iterating up til we can't + token_num += tokens_inserted; + if tokens_inserted > 0 { + continue; + } + } + Err(err) => { + tracing::error!( + token_num = token_num, + page_size = 100, + "Error updating revoked tokens: {err}" + ); + } + } + tokio::time::sleep(Duration::from_secs(5 * 60)).await; + } +} + +pub async fn check_token_revoked(token: &Biscuit) -> bool { + for identifier in token.revocation_identifiers().iter() { + let revoked_tokens = REVOKED_TOKENS.get_or_init(|| RwLock::new(HashSet::new())); + if revoked_tokens.read().await.contains(identifier.as_slice()) { + return true; + } + } + + false +} + +#[tracing::instrument(err)] +async fn single_update_revoked_tokens(token_num: u32, page_size: u32) -> anyhow::Result { + let revoked_tokens = REVOKED_TOKENS.get_or_init(|| RwLock::new(HashSet::new())); + let big_central_url = big_central_url(); + let resp = reqwest::get(format!( + "{big_central_url}/api/v1/revoked_tokens?token_num={token_num}&page_size={page_size}" + )) + .await?; + + if resp.status() != StatusCode::OK { + return Err(anyhow!("{}", resp.text().await?)); + } + + let tokens: Vec = resp.json().await?; + let num_tokens: u32 = tokens.len().try_into()?; + + let revoked_tokens = &mut revoked_tokens.write().await; + for token in tokens.into_iter() { + let token: RevocationIdentiifer = hex::decode(token)?; + revoked_tokens.insert(token); + } + + Ok(num_tokens) +} + +fn big_central_url() -> String { + env::var("BIG_CENTRAL_URL").unwrap_or_else(|_| "https://bbfs.io".to_string()) +}