From ecfee382718bc78b6d08566c3263301c5fa8b8c0 Mon Sep 17 00:00:00 2001 From: eason <30045503+Eason0729@users.noreply.github.com> Date: Mon, 4 Dec 2023 20:54:31 +0800 Subject: [PATCH] refactor: :art: remove oncecell --- backend/Cargo.lock | 14 ------- backend/Cargo.toml | 1 - backend/src/controller/submit/mod.rs | 7 ++++ backend/src/controller/submit/pubsub.rs | 54 +------------------------ backend/src/controller/submit/router.rs | 20 +++------ backend/src/endpoint/util/error.rs | 22 +++------- backend/src/init/config.rs | 12 ++---- backend/src/init/db.rs | 14 +++---- backend/src/init/logger.rs | 6 +-- backend/src/init/mod.rs | 11 +++-- backend/src/macro_tool.rs | 4 +- backend/src/server.rs | 9 ++--- 12 files changed, 44 insertions(+), 130 deletions(-) diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 38175607..057f758d 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -181,7 +181,6 @@ dependencies = [ "ring 0.17.5", "sea-orm", "serde", - "sha256", "spin 0.9.8", "thiserror", "tokio", @@ -2063,19 +2062,6 @@ dependencies = [ "digest", ] -[[package]] -name = "sha256" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7895c8ae88588ccead14ff438b939b0c569cd619116f14b4d13fdff7b8333386" -dependencies = [ - "async-trait", - "bytes", - "hex", - "sha2", - "tokio", -] - [[package]] name = "sharded-slab" version = "0.1.4" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index f8ebda9d..672d2c84 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -17,7 +17,6 @@ toml = "0.7.4" prost-types = "0.11.9" entity = { path = "./entity" } chrono = "0.4.26" -sha256 = "1.3.0" thiserror = "1.0.44" ring = "^0.17" lockfree = "0.5.1" diff --git a/backend/src/controller/submit/mod.rs b/backend/src/controller/submit/mod.rs index 9280f673..58996e66 100644 --- a/backend/src/controller/submit/mod.rs +++ b/backend/src/controller/submit/mod.rs @@ -7,6 +7,7 @@ use futures::Future; use leaky_bucket::RateLimiter; use sea_orm::{ActiveModelTrait, ActiveValue, EntityTrait, QueryOrder}; use thiserror::Error; +use tokio::sync::OnceCell; use tokio_stream::StreamExt; use tonic::Status; use uuid::Uuid; @@ -26,6 +27,8 @@ use self::{ use super::code::Code; use entity::*; +pub static SECRET: OnceCell<&'static str> = OnceCell::const_new(); + struct Waker; impl std::task::Wake for Waker { @@ -118,6 +121,10 @@ pub struct SubmitController { impl SubmitController { pub async fn new(config: &GlobalConfig) -> Result { + if let Some(secret) = &config.judger_secret { + let secret = Box::new(secret.clone()).leak(); + SECRET.set(secret).unwrap(); + }; Ok(SubmitController { router: Router::new(&config.judger).await?, pubsub: Arc::new(PubSub::default()), diff --git a/backend/src/controller/submit/pubsub.rs b/backend/src/controller/submit/pubsub.rs index 8f4fb965..80d64c4d 100644 --- a/backend/src/controller/submit/pubsub.rs +++ b/backend/src/controller/submit/pubsub.rs @@ -1,11 +1,9 @@ use spin::mutex::Mutex; use std::ops::{Deref, DerefMut}; use std::pin::Pin; -use tokio_stream::wrappers::errors::BroadcastStreamRecvError; -// use std::pin::Pin; -// use std::task::Poll; use std::{collections::HashMap, hash::Hash, sync::Arc}; use tokio::sync::broadcast::*; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; use tokio_stream::{Stream, StreamExt}; @@ -64,27 +62,6 @@ where M: Clone + Send + 'static, I: Eq + Clone + Hash + Send + 'static, { - pub fn stream( - self: &Arc, - mut stream: impl Stream + Unpin + Send + 'static, - id: I, - ) { - let tx = { - let (tx, rx) = channel(16); - self.outgoing.lock().insert(id.clone(), rx); - tx - }; - - let self_ = self.clone(); - tokio::spawn(async move { - while let Some(messenge) = stream.next().await { - if tx.send(messenge).is_err() { - log::trace!("PubSub: messege") - } - } - self_.outgoing.lock().remove(&id); - }); - } pub fn publish(self: &Arc, id: I) -> PubGuard { let (tx, rx) = channel(16); self.outgoing.lock().insert(id.clone(), rx); @@ -107,32 +84,3 @@ where }) } } - -// pub struct SubStream(BroadcastStream>); - -// impl Stream for SubStream -// where -// M: 'static + Clone + Send, -// { -// type Item = M; - -// fn poll_next( -// mut self: Pin<&mut Self>, -// cx: &mut std::task::Context<'_>, -// ) -> std::task::Poll> { -// let a = Pin::new(&mut self.0); -// if let Poll::Ready(x) = BroadcastStream::poll_next(a, cx) { -// if let Some(x) = x { -// if let Ok(x) = x { -// Poll::Ready(x) -// } else { -// Poll::Ready(None) -// } -// } else { -// Poll::Ready(None) -// } -// } else { -// Poll::Pending -// } -// } -// } diff --git a/backend/src/controller/submit/router.rs b/backend/src/controller/submit/router.rs index daf08c93..afe83087 100644 --- a/backend/src/controller/submit/router.rs +++ b/backend/src/controller/submit/router.rs @@ -1,3 +1,4 @@ +// TODO: we need docker swarm in dns setup, so we need to accept dns round robin use std::{ collections::{BTreeMap, VecDeque}, ops::DerefMut, @@ -14,10 +15,10 @@ use uuid::Uuid; use crate::{ grpc::judger::{judger_client::*, *}, - init::config::{self, CONFIG}, + init::config::{self}, }; -use super::super::submit::Error; +use super::{super::submit::Error, SECRET}; const PIPELINE: usize = 8; const JUDGER_QUE_MAX: usize = 16; @@ -30,8 +31,7 @@ type AuthIntercept = JudgerClient< >, >; fn auth_middleware(mut req: tonic::Request<()>) -> Result, tonic::Status> { - let config = CONFIG.get().unwrap(); - match &config.judger_secret { + match &SECRET.get() { Some(secret) => { let token: metadata::MetadataValue<_> = format!("basic {}", secret).parse().unwrap(); req.metadata_mut().insert("Authorization", token); @@ -189,15 +189,6 @@ impl Router { Ok(_) => upstreams.push(upstream), } } - // let futs: Box>> = configs - // .iter() - // .map(|x| async { - // let upstream = Upstream::new(Arc::new(x.clone())); - // upstream.health_check().await; - // upstream - // }) - // .collect(); - // tokio::join!(futs); if upstreams.is_empty() { return Err(Error::JudgerUnavailable); } @@ -209,8 +200,7 @@ impl Router { pub fn langs(&self) -> Vec { self.upstreams .iter() - .map(|x| x.langs()) - .flatten() + .flat_map(|x| x.langs()) .unique_by(|x| x.lang_uid.clone()) .collect() } diff --git a/backend/src/endpoint/util/error.rs b/backend/src/endpoint/util/error.rs index d845a134..7be57d80 100644 --- a/backend/src/endpoint/util/error.rs +++ b/backend/src/endpoint/util/error.rs @@ -1,3 +1,5 @@ +use crate::report_internal; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Premission deny: `{0}`")] @@ -27,19 +29,12 @@ impl From for tonic::Status { log::debug!("Client request inaccessible resource, hint: {}", x); tonic::Status::permission_denied(x) } - Error::DBErr(x) => { - log::error!("{}", x); - #[cfg(feature = "unsecured-log")] - return tonic::Status::internal(format!("{}", x)); - tonic::Status::unavailable("") - } + Error::DBErr(x) => report_internal!(error, "{}", x), // all argument should be checked before processing, // so this error is considered as internal error Error::BadArgument(x) => { - log::warn!("Client sent invaild argument: payload.{}", x); - #[cfg(feature = "unsecured-log")] - return tonic::Status::invalid_argument(format!("Bad Argument {}", x)); - tonic::Status::invalid_argument("") + log::debug!("Client sent invaild argument: payload.{}", x); + tonic::Status::invalid_argument(x) } Error::NotInPayload(x) => { log::trace!("{} is not found in client payload", x); @@ -63,12 +58,7 @@ impl From for tonic::Status { "Invaild request_id(should be a client generated UUIDv4)", ) } - Error::Unreachable(x) => { - log::error!("Function should be unreachable: {}", x); - #[cfg(feature = "unsecured-log")] - return tonic::Status::internal(format!("Function should be unreachable: {}", x)); - tonic::Status::aborted("") - } + Error::Unreachable(x) => report_internal!(error, "{}", x), } } } diff --git a/backend/src/init/config.rs b/backend/src/init/config.rs index 25ae2b54..19b6ec85 100644 --- a/backend/src/init/config.rs +++ b/backend/src/init/config.rs @@ -1,9 +1,7 @@ use std::{path::PathBuf, sync::Arc}; use serde::{Deserialize, Serialize}; -use tokio::{fs, io::AsyncReadExt, sync::OnceCell}; - -pub static CONFIG: OnceCell = OnceCell::const_new(); +use tokio::{fs, io::AsyncReadExt}; static CONFIG_PATH: &str = "config/config.toml"; @@ -66,7 +64,7 @@ impl Default for Database { } } -pub async fn init() { +pub async fn init() -> GlobalConfig { if fs::metadata(CONFIG_PATH).await.is_ok() { let mut buf = Vec::new(); let mut config = fs::File::open(CONFIG_PATH) @@ -76,14 +74,13 @@ pub async fn init() { let config = std::str::from_utf8(&buf).expect("Config file may container non-utf8 character"); let config: GlobalConfig = toml::from_str(config).unwrap(); - CONFIG.set(config).ok(); + config } else { println!("Unable to find {}, generating default config", CONFIG_PATH); let config: GlobalConfig = toml::from_str("").unwrap(); let config_txt = toml::to_string(&config).unwrap(); fs::write(CONFIG_PATH, config_txt).await.unwrap(); - CONFIG.set(config).ok(); println!( "Config generated, please edit {} before restart", @@ -96,11 +93,10 @@ pub async fn init() { #[cfg(test)] mod test { - use super::{init, CONFIG}; + use super::init; #[tokio::test] async fn default() { init().await; - assert!(CONFIG.get().is_some()); } } diff --git a/backend/src/init/db.rs b/backend/src/init/db.rs index a6188137..1ac60075 100644 --- a/backend/src/init/db.rs +++ b/backend/src/init/db.rs @@ -5,13 +5,12 @@ use sea_orm::{ActiveModelTrait, ActiveValue, Database, DatabaseConnection}; use tokio::fs; use tokio::sync::OnceCell; -use super::config::CONFIG; +use super::config::GlobalConfig; use crate::controller::token::UserPermBytes; pub static DB: OnceCell = OnceCell::const_new(); -pub async fn init() { - let config = CONFIG.get().unwrap(); +pub async fn init(config: &GlobalConfig) { let uri = format!("sqlite://{}", config.database.path.clone()); match Database::connect(&uri).await { @@ -24,7 +23,7 @@ pub async fn init() { fs::File::create(PathBuf::from(config.database.path.clone())) .await .unwrap(); - first_migration().await; + first_migration(config).await; let db: DatabaseConnection = Database::connect(&uri).await.unwrap(); @@ -33,8 +32,7 @@ pub async fn init() { } } } -fn hash(src: &str) -> Vec { - let config = CONFIG.get().unwrap(); +fn hash(config: &GlobalConfig, src: &str) -> Vec { digest::digest( &digest::SHA256, &[src.as_bytes(), config.database.salt.as_bytes()].concat(), @@ -43,7 +41,7 @@ fn hash(src: &str) -> Vec { .to_vec() } -pub async fn first_migration() { +pub async fn first_migration(config: &GlobalConfig) { let db = DB.get().unwrap(); let mut perm = UserPermBytes::default(); @@ -59,7 +57,7 @@ pub async fn first_migration() { entity::user::ActiveModel { permission: ActiveValue::Set(perm.0), username: ActiveValue::Set("admin".to_owned()), - password: ActiveValue::Set(hash("admin")), + password: ActiveValue::Set(hash(config, "admin")), ..Default::default() } .save(db) diff --git a/backend/src/init/logger.rs b/backend/src/init/logger.rs index 0d8f0511..6b27e8b8 100644 --- a/backend/src/init/logger.rs +++ b/backend/src/init/logger.rs @@ -1,10 +1,8 @@ use tracing::Level; -use super::config::CONFIG; - -pub fn init() { - let config = CONFIG.get().unwrap(); +use super::config::GlobalConfig; +pub fn init(config: &GlobalConfig) { let level = match config.log_level { 0 => Level::TRACE, 1 => Level::DEBUG, diff --git a/backend/src/init/mod.rs b/backend/src/init/mod.rs index 474235cb..56380d30 100644 --- a/backend/src/init/mod.rs +++ b/backend/src/init/mod.rs @@ -1,9 +1,12 @@ +use self::config::GlobalConfig; + pub mod config; pub mod db; pub mod logger; -pub async fn new() { - config::init().await; - logger::init(); - db::init().await; +pub async fn new() -> GlobalConfig { + let config = config::init().await; + logger::init(&config); + db::init(&config).await; + config } diff --git a/backend/src/macro_tool.rs b/backend/src/macro_tool.rs index cfe7961e..4e583536 100644 --- a/backend/src/macro_tool.rs +++ b/backend/src/macro_tool.rs @@ -16,11 +16,11 @@ macro_rules! report_internal { ($level:ident,$pattern:literal) => {{ log::$level!($pattern); tonic::Status::unknown("unknown error") - };}; + }}; ($level:ident,$pattern:literal, $error:expr) => {{ log::$level!($pattern, $error); tonic::Status::unknown("unknown error") - };}; + }}; } #[macro_export] diff --git a/backend/src/server.rs b/backend/src/server.rs index 2735a7c0..4714e676 100644 --- a/backend/src/server.rs +++ b/backend/src/server.rs @@ -11,7 +11,7 @@ use crate::{ testcase_set_server::TestcaseSetServer, token_set_server::TokenSetServer, user_set_server::UserSetServer, }, - init::config::CONFIG, + init::config::{self}, }; const MAX_FRAME_SIZE: u32 = 1024 * 1024 * 8; @@ -25,8 +25,7 @@ pub struct Server { impl Server { pub async fn start() { - let config = CONFIG.get().unwrap(); - + let config = config::init().await; log::info!("Loading TLS certificate..."); let cert = fs::read_to_string(&config.grpc.public_pem).await.unwrap(); let key = fs::read_to_string(&config.grpc.private_pem).await.unwrap(); @@ -36,9 +35,9 @@ impl Server { let server = Arc::new(Server { token: token::TokenController::new(), - submit: submit::SubmitController::new(config).await.unwrap(), + submit: submit::SubmitController::new(&config).await.unwrap(), dup: duplicate::DupController::default(), - crypto: crypto::CryptoController::new(config), + crypto: crypto::CryptoController::new(&config), }); transport::Server::builder()