Skip to content

Commit

Permalink
Use Arc<Server> instead of wrapping Server's fields in Arc, pass &RwL…
Browse files Browse the repository at this point in the history
…ock around instead
  • Loading branch information
serprex committed Jan 8, 2024
1 parent 0bdfeec commit 2cd6219
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 84 deletions.
17 changes: 8 additions & 9 deletions src/rs/server/src/handleget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
use std::collections::{BTreeMap, HashMap};
use std::convert::Infallible;
use std::fmt::Write;
use std::sync::Arc;
use std::time::SystemTime;

use http_body_util::Full;
Expand Down Expand Up @@ -64,10 +63,10 @@ impl Cache {
}
}

pub type AsyncCache = Arc<RwLock<Cache>>;
pub type AsyncCache = RwLock<Cache>;

pub async fn compress_and_cache(
cache: AsyncCache,
cache: &AsyncCache,
encoding: Encoding,
path: String,
resp: PlainResponse,
Expand Down Expand Up @@ -199,9 +198,9 @@ async fn handle_get_core(
path: &str,
ims: Option<HttpDate>,
accept: Option<AcceptEncoding>,
pgpool: PgPool,
users: AsyncUsers,
cache: AsyncCache,
pgpool: &PgPool,
users: &AsyncUsers,
cache: &AsyncCache,
) -> Result<Response<Bytes>, http::Error> {
let path = if path == "/" { "/index.html" } else { path };
if path.contains("..") || !path.starts_with('/') {
Expand Down Expand Up @@ -504,9 +503,9 @@ pub async fn handle_get(
path: &str,
ims: Option<HttpDate>,
accept: Option<AcceptEncoding>,
pgpool: PgPool,
users: AsyncUsers,
cache: AsyncCache,
pgpool: &PgPool,
users: &AsyncUsers,
cache: &AsyncCache,
) -> Response<Full<Bytes>> {
let (head, body) = handle_get_core(path, ims, accept, pgpool, users, cache).await.unwrap().into_parts();
Response::from_parts(head, Full::<Bytes>::from(body))
Expand Down
18 changes: 9 additions & 9 deletions src/rs/server/src/handlews.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ pub struct Sock {
hide: bool,
}

pub type AsyncUsers = Arc<RwLock<Users>>;
pub type AsyncSocks = Arc<RwLock<HashMap<usize, Sock>>>;
pub type AsyncUserSocks = Arc<RwLock<HashMap<String, usize>>>;
pub type AsyncUsers = RwLock<Users>;
pub type AsyncSocks = RwLock<HashMap<usize, Sock>>;
pub type AsyncUserSocks = RwLock<HashMap<String, usize>>;

fn sendmsg<T>(tx: &WsSender, val: &T)
where
Expand Down Expand Up @@ -323,11 +323,11 @@ async fn ordered_lock<'a, T>(

pub async fn handle_ws(
ws: WsStream,
pgpool: PgPool,
users: AsyncUsers,
usersocks: AsyncUserSocks,
socks: AsyncSocks,
tls: TlsConnector,
pgpool: &PgPool,
users: &AsyncUsers,
usersocks: &AsyncUserSocks,
socks: &AsyncSocks,
tls: &TlsConnector,
) {
let sockid = NEXT_SOCK_ID.fetch_add(1, Ordering::Relaxed);

Expand Down Expand Up @@ -649,7 +649,7 @@ pub async fn handle_ws(
awon + aloss
};
let newscore =
(wilson(awon, awon + aloss) * (1.0 - decay / (decay + 96.0)) * 1000.0) as i32;
(wilson(awon, awon + aloss) * (1.0 - decay / (decay + 80.0)) * 1000.0) as i32;
trx.execute(
if won {
"update arena set won = won+1, score = $3 where arena_id = $1 and user_id = $2"
Expand Down
65 changes: 31 additions & 34 deletions src/rs/server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use bb8_postgres::{bb8::Pool, tokio_postgres, PostgresConnectionManager};
use crate::handleget::AsyncCache;
use crate::handlews::{AsyncSocks, AsyncUserSocks, AsyncUsers};

pub type PgPool = Arc<Pool<PostgresConnectionManager<tokio_postgres::NoTls>>>;
pub type PgPool = Pool<PostgresConnectionManager<tokio_postgres::NoTls>>;
pub type WsStream = WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>;

pub fn get_day() -> u32 {
Expand Down Expand Up @@ -76,7 +76,6 @@ impl From<ConfigRaw> for Config {

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

#[derive(Clone)]
struct Server {
pub users: AsyncUsers,
pub usersocks: AsyncUserSocks,
Expand All @@ -86,24 +85,29 @@ struct Server {
pub tls: TlsConnector,
}

impl hyper::service::Service<Request<Incoming>> for Server {
pub struct ServerService(Arc<Server>);

impl hyper::service::Service<Request<Incoming>> for ServerService {
type Response = Response<Full<Bytes>>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn call(&self, mut req: Request<Incoming>) -> Self::Future {
let pgpool = self.pgpool.clone();
let users = self.users.clone();
let usersocks = self.usersocks.clone();
let socks = self.socks.clone();
let cache = self.cache.clone();
let tls = self.tls.clone();
let server = self.0.clone();
Box::pin(async move {
if hyper_tungstenite::is_upgrade_request(&req) {
if let Ok((response, socket)) = hyper_tungstenite::upgrade(&mut req, None) {
tokio::spawn(async move {
if let Ok(ws) = socket.await {
handlews::handle_ws(ws, pgpool, users, usersocks, socks, tls).await
handlews::handle_ws(
ws,
&server.pgpool,
&server.users,
&server.usersocks,
&server.socks,
&server.tls,
)
.await
}
});

Expand All @@ -122,9 +126,9 @@ impl hyper::service::Service<Request<Incoming>> for Server {
.get("accept-encoding")
.and_then(|hv| hv.to_str().ok())
.and_then(|hv| hv.parse().ok()),
pgpool,
users,
cache,
&server.pgpool,
&server.users,
&server.cache,
)
.await)
}
Expand All @@ -145,34 +149,28 @@ async fn main() {
}
(
listen,
Arc::new(
Pool::builder()
.build(PostgresConnectionManager::new(pg, tokio_postgres::NoTls))
.await
.expect("Failed to create connection pool"),
),
Pool::builder()
.build(PostgresConnectionManager::new(pg, tokio_postgres::NoTls))
.await
.expect("Failed to create connection pool"),
)
};

let (closetx, closerx) = tokio::sync::watch::channel(());
let mut gccloserx = closerx.clone();

let users = AsyncUsers::default();
let usersocks = AsyncUserSocks::default();
let socks = AsyncSocks::default();
let cache = AsyncCache::default();
let gcusers = users.clone();
let gcusersocks = usersocks.clone();
let gcsocks = socks.clone();
let gcpgpool = pgpool.clone();
let mut gccloserx = closerx.clone();
let sigintusers = users.clone();
let sigintpgpool = pgpool.clone();
let tlsconfig = ClientConfig::builder()
.with_root_certificates(RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(),
})
.with_no_client_auth();
let tls = TlsConnector::from(Arc::new(tlsconfig));
let server = Arc::new(Server { pgpool, users, usersocks, socks, cache, tls });
let gc = server.clone();

let mut interval = tokio::time::interval(Duration::new(300, 0));
tokio::spawn(async move {
Expand All @@ -188,10 +186,10 @@ async fn main() {
}
_ = interval.tick() => (),
}
if let Ok(client) = gcpgpool.get().await {
let mut users = gcusers.write().await;
if let Ok(client) = gc.pgpool.get().await {
let mut users = gc.users.write().await;
let _ = tokio::join!(users
.store(&client, gcusersocks.clone(), gcsocks.clone()),
.store(&client, &gc.usersocks, &gc.socks),
client.execute("delete from trade_request where expire_at < now()", &[]),
client.execute(
"with expiredids (id) as (select id from games where expire_at < now()) \
Expand All @@ -204,7 +202,6 @@ async fn main() {

let mut sigintstream = signal(SignalKind::interrupt()).expect("Failed to setup signal handler");
let listener = tokio::net::TcpListener::bind((Ipv4Addr::new(0, 0, 0, 0), listenport)).await.unwrap();
let server = Server { pgpool, users, usersocks, socks, cache, tls };
let mut http = hyper::server::conn::http1::Builder::new();
http.keep_alive(true);

Expand All @@ -214,7 +211,7 @@ async fn main() {
_ = sigintstream.recv() => break,
accepted = listener.accept() => {
if let Ok((stream, _)) = accepted {
let connection = http.serve_connection(TokioIo::new(stream), server.clone()).with_upgrades();
let connection = http.serve_connection(TokioIo::new(stream), ServerService(server.clone())).with_upgrades();
tokio::spawn(async move {
if let Err(err) = connection.await {
println!("Error serving HTTP connection: {err:?}");
Expand All @@ -227,12 +224,12 @@ async fn main() {

drop(closetx);
println!("Shutting down");
if let Ok(client) = sigintpgpool.get().await {
if !sigintusers.write().await.saveall(&client).await {
if let Ok(client) = server.pgpool.get().await {
if !server.users.read().await.saveall(&client).await {
println!("Error while saving users");
}
} else {
println!("Failed to connect");
}
drop(sigintpgpool);
drop(server)
}
56 changes: 24 additions & 32 deletions src/rs/server/src/users.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use bb8_postgres::tokio_postgres::types::{FromSql, ToSql};
Expand Down Expand Up @@ -114,39 +115,37 @@ impl UserObject {
pub type User = Arc<Mutex<UserObject>>;

#[derive(Default)]
pub struct Users(HashMap<String, (bool, User)>);
pub struct Users(HashMap<String, (AtomicBool, User)>);

impl Users {
pub async fn load<GC>(&mut self, client: &GC, name: &str) -> Option<User>
pub async fn load<GC>(&self, client: &GC, name: &str) -> Option<User>
where
GC: GenericClient,
{
if let Some(&mut (ref mut gc, ref user)) = self.0.get_mut(name) {
*gc = false;
if let Some((ref gc, ref user)) = self.0.get(name) {
gc.store(false, Ordering::Release);
Some(user.clone())
} else if let Some(row) = client.query_opt("select u.id, u.auth, u.salt, u.iter, u.algo, ud.data from user_data ud join users u on u.id = ud.user_id where u.name = $1 and ud.type_id = 1", &[&name]).await.expect("Connection failed while loading user") {
let Json(userdata) = row.try_get::<usize, Json<UserData>>(5).expect("Invalid json for user");
let namestr = name.to_string();
let userarc = Arc::new(Mutex::new(UserObject {
name: namestr.clone(),
id: row.get::<usize, i64>(0),
auth: row.get::<usize, String>(1),
salt: row.get::<usize, Vec<u8>>(2),
iter: row.get::<usize, i32>(3) as u32,
algo: row.get::<usize, HashAlgo>(4),
data: userdata,
}));
self.insert(namestr, userarc.clone());
Some(userarc)
} else {
if let Some(row) = client.query_opt("select u.id, u.auth, u.salt, u.iter, u.algo, ud.data from user_data ud join users u on u.id = ud.user_id where u.name = $1 and ud.type_id = 1", &[&name]).await.expect("Connection failed while loading user") {
let Json(userdata) = row.try_get::<usize, Json<UserData>>(5).expect("Invalid json for user");
let namestr = name.to_string();
let userarc = Arc::new(Mutex::new(UserObject {
name: namestr.clone(),
id: row.get::<usize, i64>(0),
auth: row.get::<usize, String>(1),
salt: row.get::<usize, Vec<u8>>(2),
iter: row.get::<usize, i32>(3) as u32,
algo: row.get::<usize, HashAlgo>(4),
data: userdata,
}));
self.insert(namestr, userarc.clone());
Some(userarc)
} else {
None
}
None
}
}

pub fn insert(&mut self, name: String, user: User) {
self.0.insert(name, (false, user));
self.0.insert(name, (AtomicBool::new(false), user));
}

pub fn remove(&mut self, name: &str) {
Expand All @@ -166,7 +165,7 @@ impl Users {
}
}

pub async fn saveall(&mut self, client: &Client) -> bool {
pub async fn saveall(&self, client: &Client) -> bool {
let mut queries = Vec::new();
for &(_, ref user) in self.0.values() {
queries.push(async move {
Expand All @@ -182,19 +181,12 @@ impl Users {
futures::future::join_all(queries).await.into_iter().all(|x| x.is_ok())
}

pub async fn store(&mut self, client: &Client, usersocks: AsyncUserSocks, socks: AsyncSocks) {
pub async fn store(&mut self, client: &Client, usersocks: &AsyncUserSocks, socks: &AsyncSocks) {
if self.saveall(client).await {
let mut usersocks = usersocks.write().await;
let socks = socks.read().await;
usersocks.retain(|_, v| socks.contains_key(v));
self.0.retain(|_, &mut (ref mut gc, _)| {
if *gc {
false
} else {
*gc = true;
true
}
});
self.0.retain(|_, (ref gc, _)| gc.swap(false, Ordering::AcqRel));
}
}
}

0 comments on commit 2cd6219

Please sign in to comment.