Skip to content

Commit

Permalink
Merge AsyncUserSocks into AsyncUsers, make sockid AtomicUsize
Browse files Browse the repository at this point in the history
Revise authentication logic to use constant time string comparison

Greatly reduce write locks on users hashmap,
particularly when authenticating messages attempt to optimistically load with a read lock
  • Loading branch information
serprex committed Jan 10, 2024
1 parent e0f06f0 commit 9696023
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 101 deletions.
114 changes: 34 additions & 80 deletions src/rs/server/src/handlews.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

Expand Down Expand Up @@ -35,7 +36,7 @@ use crate::starters::{ORIGINAL_STARTERS, STARTERS};
use crate::users::{self, HashAlgo, UserData, UserObject, UserRole, Users};
use crate::{get_day, PgPool, WsStream};

static NEXT_SOCK_ID: AtomicUsize = AtomicUsize::new(0);
static NEXT_SOCK_ID: AtomicUsize = AtomicUsize::new(1);

const SELL_VALUES: [u8; 5] = [5, 1, 3, 15, 150];

Expand Down Expand Up @@ -73,8 +74,7 @@ pub struct Sock {
}

pub type AsyncUsers = RwLock<Users>;
pub type AsyncSocks = RwLock<HashMap<usize, Sock>>;
pub type AsyncUserSocks = RwLock<HashMap<String, usize>>;
pub type AsyncSocks = RwLock<HashMap<NonZeroUsize, Sock>>;

fn sendmsg<T>(tx: &WsSender, val: &T)
where
Expand Down Expand Up @@ -144,14 +144,7 @@ where
client.execute("with arank as (select user_id, arena_id, \"rank\", (row_number() over (partition by arena_id order by score desc, day desc, \"rank\"))::int realrank from arena) update arena set \"rank\" = realrank, bestrank = least(bestrank, realrank) from arank where arank.arena_id = arena.arena_id and arank.user_id = arena.user_id and arank.realrank <> arank.\"rank\"", &[])
}

async fn login_success(
usersocks: &AsyncUserSocks,
tx: &WsSender,
sockid: usize,
user: &mut UserObject,
username: &str,
client: &mut Client,
) {
async fn login_success(tx: &WsSender, user: &mut UserObject, username: &str, client: &mut Client) {
if user.id != -1 {
let today = get_day();
let oracle = user.data.oracle;
Expand Down Expand Up @@ -181,9 +174,7 @@ async fn login_success(
}

if let Ok(userstr) = serde_json::to_string(&WsResponse::login(&*user)) {
if tx.send(Message::Text(userstr)).is_ok() {
usersocks.write().await.insert(String::from(username), sockid);
}
tx.send(Message::Text(userstr)).ok();
}

if user.data.daily == 0 {
Expand Down Expand Up @@ -325,11 +316,10 @@ pub async fn handle_ws(
ws: WsStream,
pgpool: &PgPool,
users: &AsyncUsers,
usersocks: &AsyncUserSocks,
socks: &AsyncSocks,
tls: &TlsConnector,
) {
let sockid = NEXT_SOCK_ID.fetch_add(1, Ordering::Relaxed);
let Some(sockid) = NonZeroUsize::new(NEXT_SOCK_ID.fetch_add(1, Ordering::Relaxed)) else { return };

let (mut user_ws_tx, mut user_ws_rx) = ws.split();
let (tx, rx) = mpsc::unbounded_channel();
Expand All @@ -345,27 +335,13 @@ pub async fn handle_ws(
'msgloop: while let Some(Ok(result)) = user_ws_rx.next().await {
let Message::Text(msg) = result else { continue };
if let Ok(msg) = serde_json::from_str::<UserMessage>(&msg) {
let mut client = pgpool.get().await.expect("Failed to acquire sql connection");
let Ok(mut client) = pgpool.get().await else { continue };

match msg {
UserMessage::a { u, a, msg } => {
let (user, userid) = if let Ok(row) =
client.query_one("select id, auth from users where name = $1", &[&u]).await
{
if a == row.get::<usize, &str>(1) {
if let Some(user) = users.write().await.load(&*client, &u).await {
let rightsock = usersocks.read().await.get(&u) == Some(&sockid);
if !rightsock {
usersocks.write().await.insert(u.clone(), sockid);
}
(user, row.get::<usize, i64>(0))
} else {
continue;
}
} else {
continue;
}
} else {
let Some((user, userid)) =
Users::load_with_auth(users, &*client, &u, &a, sockid).await
else {
continue;
};
match msg {
Expand All @@ -379,8 +355,7 @@ pub async fn handle_ws(
if u == "serprex" {
let mut wusers = users.write().await;
if let Some(user) = wusers.load(&*client, &m).await {
let mut user = user.lock().await;
user.auth = String::new();
user.lock().await.auth.clear();
}
}
}
Expand Down Expand Up @@ -478,7 +453,7 @@ pub async fn handle_ws(
decks.insert(String::from("2"), String::from(sid.2));
decks.insert(String::from("3"), String::from(sid.3));
user.data.decks = decks;
login_success(&usersocks, &tx, sockid, &mut *user, &u, &mut client)
login_success(&tx, &mut *user, &u, &mut client)
.await;
}
}
Expand Down Expand Up @@ -514,11 +489,7 @@ pub async fn handle_ws(
}
}
AuthMessage::logout => {
let mut wusers = users.write().await;
let mut wusersocks = usersocks.write().await;
wusersocks.remove(&u);
drop(wusersocks);
wusers.evict(&client, &u).await;
users.write().await.evict(&client, &u).await;
}
AuthMessage::delete => {
let params: &[&(dyn ToSql + Sync)] = &[&userid];
Expand All @@ -535,10 +506,7 @@ pub async fn handle_ws(
.is_ok()
{
trx.commit().await.ok();
let mut wusers = users.write().await;
let mut wusersocks = usersocks.write().await;
wusersocks.remove(&u);
wusers.remove(&u);
users.write().await.remove(&u);
}
}
}
Expand Down Expand Up @@ -1030,10 +998,8 @@ pub async fn handle_ws(
},
)
};
if let Some(foesockid) = usersocks.read().await.get(&f) {
if let Some((foesockid, foeuser)) = users.read().await.get(&f) {
if let Some(foesock) = socks.read().await.get(&foesockid) {
let mut wusers = users.write().await;
if let Some(foeuser) = wusers.load(&*client, &f).await {
let foeuserid = foeuser.lock().await.id;
if let Ok(trx) = client.transaction().await {
trx.execute("delete from match_request mr1 where user_id = $1 and accepted", &[&userid]).await.ok();
Expand Down Expand Up @@ -1107,7 +1073,6 @@ pub async fn handle_ws(
}
}
}
}
}
}
AuthMessage::r#move {
Expand All @@ -1117,15 +1082,15 @@ pub async fn handle_ws(
cmd,
} => {
if let Ok(trx) = client.transaction().await {
if let (Ok(moves), Ok(users)) = (
if let (Ok(moves), Ok(urows)) = (
trx.query_one(
"select g.moves from games g join match_request mr on mr.game_id = g.id join users u on u.id = mr.user_id where g.id = $1 and u.id = $2 for update",
&[&id, &userid]).await,
trx.query(
"select u.id, u.name from match_request mr join users u on mr.user_id = u.id where mr.game_id = $1",
&[&id]).await,
) {
if users.iter().all(|row| row.get::<usize, i64>(0) != userid) {
if urows.iter().all(|row| row.get::<usize, i64>(0) != userid) {
sendmsg(&tx, &WsResponse::chat {
mode: 1,
msg: "You aren't in that match",
Expand All @@ -1142,13 +1107,13 @@ pub async fn handle_ws(
"update games set moves = array_append(moves, $2), expire_at = now() + interval '1 hour' where id = $1",
&[&id, &Json(GamesMove { cmd, hash })]).await.is_ok() && trx.commit().await.is_ok() {
if let Ok(movejson) = serde_json::to_string(&WsResponse::r#move { cmd, hash }) {
let rusersocks = usersocks.read().await;
let rusers = users.read().await;
let rsocks = socks.read().await;
for row in users.iter() {
for row in urows.iter() {
let uid: i64 = row.get(0);
if uid != userid {
let name: &str = row.get(1);
if let Some(sockid) = rusersocks.get(name) {
if let Some(ref sockid) = rusers.get_sockid(name) {
if let Some(sock) = rsocks.get(sockid) {
sock.tx.send(Message::Text(movejson.clone())).ok();
}
Expand Down Expand Up @@ -1226,10 +1191,8 @@ pub async fn handle_ws(
}
AuthMessage::canceltrade { f } => {
if u != f {
if let Some(foesockid) = usersocks.read().await.get(&f) {
if let Some((foesockid, foeuser)) = users.read().await.get(&f) {
if let Some(foesock) = socks.read().await.get(&foesockid) {
let mut wusers = users.write().await;
if let Some(foeuser) = wusers.load(&*client, &f).await {
sendmsg(
&foesock.tx,
&WsResponse::tradecanceled { u: &u },
Expand All @@ -1247,14 +1210,12 @@ pub async fn handle_ws(
"delete from trade_request where (user_id = $1 and for_user_id = $2) or (user_id = $2 and for_user_id = $1)",
&[&userid, &foeuserid]).await.ok();
}
}
}
}
}
AuthMessage::reloadtrade { f } => {
if u != f {
let mut wusers = users.write().await;
if let Some(foeuser) = wusers.load(&*client, &f).await {
if let Some((_, foeuser)) = users.read().await.get(&f) {
let foeuserid = foeuser.lock().await.id;
if let Ok(trade) = client.query_one(
"select cards, g from trade_request where user_id = $2 and for_user_id = $1", &[&userid, &foeuserid]).await {
Expand All @@ -1277,10 +1238,8 @@ pub async fn handle_ws(
g,
} => {
if u != f {
if let Some(foesockid) = usersocks.read().await.get(&f) {
if let Some((foesockid, foeuser)) = users.read().await.get(&f) {
if let Some(foesock) = socks.read().await.get(&foesockid) {
let mut wusers = users.write().await;
if let Some(foeuser) = wusers.load(&*client, &f).await {
let (mut user, mut foeuser) =
ordered_lock(&user, &foeuser).await;
if let Ok(trx) = client.transaction().await {
Expand Down Expand Up @@ -1394,7 +1353,6 @@ pub async fn handle_ws(
});
}
}
}
}
}
}
Expand All @@ -1418,7 +1376,7 @@ pub async fn handle_ws(
sendmsg(&tx, &WsResponse::passchange { auth: &user.auth });
}
AuthMessage::challrecv { f, trade } => {
if let Some(foesockid) = usersocks.read().await.get(&f) {
if let Some(foesockid) = users.read().await.get_sockid(&f) {
if let Some(foesock) = socks.read().await.get(&foesockid) {
sendmsg(
&foesock.tx,
Expand All @@ -1438,7 +1396,7 @@ pub async fn handle_ws(
AuthMessage::chat { to, msg } => {
if let Some(to) = to {
let mut sent = false;
if let Some(tosockid) = usersocks.read().await.get(&to) {
if let Some(tosockid) = users.read().await.get_sockid(&to) {
if let Some(sock) = socks.read().await.get(&tosockid) {
if serde_json::to_string(&WsResponse::chatu {
mode: 2,
Expand Down Expand Up @@ -1774,7 +1732,6 @@ pub async fn handle_ws(
);
drop(user);
let mut wusers = users.write().await;
let rusersocks = usersocks.read().await;
let rsocks = socks.read().await;
for sell in sells {
{
Expand Down Expand Up @@ -1806,8 +1763,7 @@ pub async fn handle_ws(
if let Some(selltx) = if sell.u == u {
Some(tx.clone())
} else {
rusersocks
.get(&sell.u)
wusers.get_sockid(&sell.u)
.and_then(|sockid| rsocks.get(&sockid))
.map(|sock| sock.tx.clone())
} {
Expand Down Expand Up @@ -2176,7 +2132,7 @@ pub async fn handle_ws(
algo: users::HASH_ALGO,
data: UserData { oracle: u32::MAX, ..Default::default() },
}));
wusers.insert(u.clone(), user.clone());
wusers.insert(u.clone(), sockid.get(), user.clone());
user
};
let mut user = user.lock().await;
Expand Down Expand Up @@ -2211,7 +2167,8 @@ pub async fn handle_ws(
} else {
user.auth.is_empty()
} {
login_success(&usersocks, &tx, sockid, &mut *user, &u, &mut client).await;
wusers.set_sockid(&u, sockid.get());
login_success(&tx, &mut *user, &u, &mut client).await;
} else {
sendmsg(&tx, &WsResponse::loginfail { err: "Authentication failed" });
}
Expand All @@ -2238,7 +2195,6 @@ pub async fn handle_ws(
(1..output.len()).into_iter().rev().find(|&idx| {
output[idx - 1] == b'\n' && output[idx] == b'{'
}) {
println!("{}", String::from_utf8_lossy(&output));
if let Ok(Value::Object(body)) =
serde_json::from_slice::<Value>(&output[pos..])
{
Expand All @@ -2260,14 +2216,13 @@ pub async fn handle_ws(
let mut user = user.lock().await;
user.auth = g.clone();
login_success(
&usersocks,
&tx,
sockid,
&mut user,
&name,
&mut client,
)
.await;
wusers.set_sockid(&name, sockid.get());
} else {
let mut newuser = UserObject {
name: name.clone(),
Expand All @@ -2279,16 +2234,15 @@ pub async fn handle_ws(
data: Default::default(),
};
login_success(
&usersocks,
&tx,
sockid,
&mut newuser,
&name,
&mut client,
)
.await;
wusers.insert(
name,
sockid.get(),
Arc::new(Mutex::new(newuser)),
);
}
Expand Down Expand Up @@ -2476,10 +2430,10 @@ pub async fn handle_ws(
UserMessage::who => {
let mut res = String::new();
{
let rusersocks = usersocks.read().await;
let rusers = users.read().await;
let rsocks = socks.read().await;
for (name, id) in rusersocks.iter() {
if let Some(sock) = rsocks.get(id) {
for (name, id) in rusers.iter_name_sockid() {
if let Some(sock) = rsocks.get(&id) {
if !sock.hide {
if !res.is_empty() {
res.push_str(", ");
Expand Down
Loading

0 comments on commit 9696023

Please sign in to comment.