diff --git a/src/rs/server/src/handleget.rs b/src/rs/server/src/handleget.rs index 3f5f089e..ce37a2dc 100644 --- a/src/rs/server/src/handleget.rs +++ b/src/rs/server/src/handleget.rs @@ -3,7 +3,6 @@ use std::collections::{BTreeMap, HashMap}; use std::convert::Infallible; use std::fmt::Write; -use std::sync::Arc; use std::time::SystemTime; use http_body_util::Full; @@ -64,10 +63,10 @@ impl Cache { } } -pub type AsyncCache = Arc>; +pub type AsyncCache = RwLock; pub async fn compress_and_cache( - cache: AsyncCache, + cache: &AsyncCache, encoding: Encoding, path: String, resp: PlainResponse, @@ -199,9 +198,9 @@ async fn handle_get_core( path: &str, ims: Option, accept: Option, - pgpool: PgPool, - users: AsyncUsers, - cache: AsyncCache, + pgpool: &PgPool, + users: &AsyncUsers, + cache: &AsyncCache, ) -> Result, http::Error> { let path = if path == "/" { "/index.html" } else { path }; if path.contains("..") || !path.starts_with('/') { @@ -504,9 +503,9 @@ pub async fn handle_get( path: &str, ims: Option, accept: Option, - pgpool: PgPool, - users: AsyncUsers, - cache: AsyncCache, + pgpool: &PgPool, + users: &AsyncUsers, + cache: &AsyncCache, ) -> Response> { let (head, body) = handle_get_core(path, ims, accept, pgpool, users, cache).await.unwrap().into_parts(); Response::from_parts(head, Full::::from(body)) diff --git a/src/rs/server/src/handlews.rs b/src/rs/server/src/handlews.rs index e24ec081..94070a7d 100644 --- a/src/rs/server/src/handlews.rs +++ b/src/rs/server/src/handlews.rs @@ -72,9 +72,9 @@ pub struct Sock { hide: bool, } -pub type AsyncUsers = Arc>; -pub type AsyncSocks = Arc>>; -pub type AsyncUserSocks = Arc>>; +pub type AsyncUsers = RwLock; +pub type AsyncSocks = RwLock>; +pub type AsyncUserSocks = RwLock>; fn sendmsg(tx: &WsSender, val: &T) where @@ -323,11 +323,11 @@ async fn ordered_lock<'a, T>( pub async fn handle_ws( ws: WsStream, - pgpool: PgPool, - users: AsyncUsers, - usersocks: AsyncUserSocks, - socks: AsyncSocks, - tls: TlsConnector, + pgpool: &PgPool, + users: &AsyncUsers, + usersocks: &AsyncUserSocks, + socks: &AsyncSocks, + tls: &TlsConnector, ) { let sockid = NEXT_SOCK_ID.fetch_add(1, Ordering::Relaxed); @@ -649,7 +649,7 @@ pub async fn handle_ws( awon + aloss }; let newscore = - (wilson(awon, awon + aloss) * (1.0 - decay / (decay + 96.0)) * 1000.0) as i32; + (wilson(awon, awon + aloss) * (1.0 - decay / (decay + 80.0)) * 1000.0) as i32; trx.execute( if won { "update arena set won = won+1, score = $3 where arena_id = $1 and user_id = $2" diff --git a/src/rs/server/src/main.rs b/src/rs/server/src/main.rs index 07a2ecd4..43829c47 100644 --- a/src/rs/server/src/main.rs +++ b/src/rs/server/src/main.rs @@ -36,7 +36,7 @@ use bb8_postgres::{bb8::Pool, tokio_postgres, PostgresConnectionManager}; use crate::handleget::AsyncCache; use crate::handlews::{AsyncSocks, AsyncUserSocks, AsyncUsers}; -pub type PgPool = Arc>>; +pub type PgPool = Pool>; pub type WsStream = WebSocketStream>; pub fn get_day() -> u32 { @@ -76,7 +76,6 @@ impl From for Config { type Error = Box; -#[derive(Clone)] struct Server { pub users: AsyncUsers, pub usersocks: AsyncUserSocks, @@ -86,24 +85,27 @@ struct Server { pub tls: TlsConnector, } -impl hyper::service::Service> for Server { +pub struct ServerService(Arc); + +impl hyper::service::Service> for ServerService { type Response = Response>; type Error = Error; type Future = Pin> + Send>>; fn call(&self, mut req: Request) -> Self::Future { - let pgpool = self.pgpool.clone(); - let users = self.users.clone(); - let usersocks = self.usersocks.clone(); - let socks = self.socks.clone(); - let cache = self.cache.clone(); - let tls = self.tls.clone(); + let server = self.0.clone(); Box::pin(async move { if hyper_tungstenite::is_upgrade_request(&req) { if let Ok((response, socket)) = hyper_tungstenite::upgrade(&mut req, None) { tokio::spawn(async move { if let Ok(ws) = socket.await { - handlews::handle_ws(ws, pgpool, users, usersocks, socks, tls).await + handlews::handle_ws(ws, + &server.pgpool, + &server.users, + &server.usersocks, + &server.socks, + &server.tls, + ).await } }); @@ -122,9 +124,9 @@ impl hyper::service::Service> for Server { .get("accept-encoding") .and_then(|hv| hv.to_str().ok()) .and_then(|hv| hv.parse().ok()), - pgpool, - users, - cache, + &server.pgpool, + &server.users, + &server.cache, ) .await) } @@ -145,34 +147,28 @@ async fn main() { } ( listen, - Arc::new( - Pool::builder() - .build(PostgresConnectionManager::new(pg, tokio_postgres::NoTls)) - .await - .expect("Failed to create connection pool"), - ), + Pool::builder() + .build(PostgresConnectionManager::new(pg, tokio_postgres::NoTls)) + .await + .expect("Failed to create connection pool"), ) }; let (closetx, closerx) = tokio::sync::watch::channel(()); + let mut gccloserx = closerx.clone(); let users = AsyncUsers::default(); let usersocks = AsyncUserSocks::default(); let socks = AsyncSocks::default(); let cache = AsyncCache::default(); - let gcusers = users.clone(); - let gcusersocks = usersocks.clone(); - let gcsocks = socks.clone(); - let gcpgpool = pgpool.clone(); - let mut gccloserx = closerx.clone(); - let sigintusers = users.clone(); - let sigintpgpool = pgpool.clone(); let tlsconfig = ClientConfig::builder() .with_root_certificates(RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(), }) .with_no_client_auth(); let tls = TlsConnector::from(Arc::new(tlsconfig)); + let server = Arc::new(Server { pgpool, users, usersocks, socks, cache, tls }); + let gc = server.clone(); let mut interval = tokio::time::interval(Duration::new(300, 0)); tokio::spawn(async move { @@ -188,10 +184,10 @@ async fn main() { } _ = interval.tick() => (), } - if let Ok(client) = gcpgpool.get().await { - let mut users = gcusers.write().await; + if let Ok(client) = gc.pgpool.get().await { + let mut users = gc.users.write().await; let _ = tokio::join!(users - .store(&client, gcusersocks.clone(), gcsocks.clone()), + .store(&client, &gc.usersocks, &gc.socks), client.execute("delete from trade_request where expire_at < now()", &[]), client.execute( "with expiredids (id) as (select id from games where expire_at < now()) \ @@ -204,7 +200,6 @@ async fn main() { let mut sigintstream = signal(SignalKind::interrupt()).expect("Failed to setup signal handler"); let listener = tokio::net::TcpListener::bind((Ipv4Addr::new(0, 0, 0, 0), listenport)).await.unwrap(); - let server = Server { pgpool, users, usersocks, socks, cache, tls }; let mut http = hyper::server::conn::http1::Builder::new(); http.keep_alive(true); @@ -214,7 +209,7 @@ async fn main() { _ = sigintstream.recv() => break, accepted = listener.accept() => { if let Ok((stream, _)) = accepted { - let connection = http.serve_connection(TokioIo::new(stream), server.clone()).with_upgrades(); + let connection = http.serve_connection(TokioIo::new(stream), ServerService(server.clone())).with_upgrades(); tokio::spawn(async move { if let Err(err) = connection.await { println!("Error serving HTTP connection: {err:?}"); @@ -227,12 +222,12 @@ async fn main() { drop(closetx); println!("Shutting down"); - if let Ok(client) = sigintpgpool.get().await { - if !sigintusers.write().await.saveall(&client).await { + if let Ok(client) = server.pgpool.get().await { + if !server.users.write().await.saveall(&client).await { println!("Error while saving users"); } } else { println!("Failed to connect"); } - drop(sigintpgpool); + drop(server) } diff --git a/src/rs/server/src/users.rs b/src/rs/server/src/users.rs index 5397a56a..5c7fd8d8 100644 --- a/src/rs/server/src/users.rs +++ b/src/rs/server/src/users.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use bb8_postgres::tokio_postgres::types::{FromSql, ToSql}; use bb8_postgres::tokio_postgres::{types::Json, Client, GenericClient}; @@ -114,15 +115,15 @@ impl UserObject { pub type User = Arc>; #[derive(Default)] -pub struct Users(HashMap); +pub struct Users(HashMap); impl Users { pub async fn load(&mut self, client: &GC, name: &str) -> Option where GC: GenericClient, { - if let Some(&mut (ref mut gc, ref user)) = self.0.get_mut(name) { - *gc = false; + if let Some((ref gc, ref user)) = self.0.get(name) { + gc.store(false, Ordering::Release); Some(user.clone()) } else { if let Some(row) = client.query_opt("select u.id, u.auth, u.salt, u.iter, u.algo, ud.data from user_data ud join users u on u.id = ud.user_id where u.name = $1 and ud.type_id = 1", &[&name]).await.expect("Connection failed while loading user") { @@ -146,7 +147,7 @@ impl Users { } pub fn insert(&mut self, name: String, user: User) { - self.0.insert(name, (false, user)); + self.0.insert(name, (AtomicBool::new(false), user)); } pub fn remove(&mut self, name: &str) { @@ -166,7 +167,7 @@ impl Users { } } - pub async fn saveall(&mut self, client: &Client) -> bool { + pub async fn saveall(&self, client: &Client) -> bool { let mut queries = Vec::new(); for &(_, ref user) in self.0.values() { queries.push(async move { @@ -182,19 +183,12 @@ impl Users { futures::future::join_all(queries).await.into_iter().all(|x| x.is_ok()) } - pub async fn store(&mut self, client: &Client, usersocks: AsyncUserSocks, socks: AsyncSocks) { + pub async fn store(&mut self, client: &Client, usersocks: &AsyncUserSocks, socks: &AsyncSocks) { if self.saveall(client).await { let mut usersocks = usersocks.write().await; let socks = socks.read().await; usersocks.retain(|_, v| socks.contains_key(v)); - self.0.retain(|_, &mut (ref mut gc, _)| { - if *gc { - false - } else { - *gc = true; - true - } - }); + self.0.retain(|_, (ref gc, _)| gc.swap(false, Ordering::AcqRel)); } } }