Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/fix_server/example_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ where
match self.serde_plugin.process_incoming(message, session_id) {
Ok(result) => {
let user_id = &result.get_user_id();
self.user_plugin.add_add_user_session(&user_id, session_id);
self.user_plugin.add_user_session(&user_id, session_id);

let seq_num = result.get_seq_num(); // Ensure R has this method or adjust accordingly
if self.seq_num_plugin.valid_seq_num(seq_num, session_id) {
Expand Down
2 changes: 1 addition & 1 deletion libs/axum-fix-server/src/plugins/user_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl UserPlugin {
}
}

pub fn add_add_user_session(&self, user_id: &(u32, Address), session_id: &SessionId) {
pub fn add_user_session(&self, user_id: &(u32, Address), session_id: &SessionId) {
let mut write_lock = self.users_sessions.write().unwrap();
write_lock
.entry(*user_id)
Expand Down
31 changes: 11 additions & 20 deletions libs/axum-fix-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use axum::{
use eyre::{eyre, Result};
use futures_util::future::join_all;
use itertools::Itertools;
use parking_lot::RwLock;
use std::net::SocketAddr;

/// Server
Expand All @@ -25,7 +24,7 @@ use std::net::SocketAddr;
/// The `Plugin` consumes the message and is responsible for deserialization,
/// message validation (fields, seqnum, signatures), publishing to application, etc.
pub struct Server<Response, Plugin> {
server_state: Arc<RwLock<ServerState<Response, Plugin>>>,
server_state: Arc<ServerState<Response, Plugin>>,
}

impl<Response, Plugin> Server<Response, Plugin>
Expand All @@ -39,18 +38,12 @@ where
/// Initializes the session ID counter to start from 1.
pub fn new(plugin: Plugin) -> Self {
Self {
server_state: Arc::new(RwLock::new(ServerState::new(plugin))),
server_state: Arc::new(ServerState::new(plugin)),
}
}

pub fn with_plugin<Ret>(&self, cb: impl FnOnce(&Plugin) -> Ret) -> Ret {
let state = self.server_state.read();
cb(state.get_plugin())
}

pub fn with_plugin_mut<Ret>(&self, cb: impl FnOnce(&mut Plugin) -> Ret) -> Ret {
let mut state = self.server_state.write();
cb(state.get_plugin_mut())
cb(self.server_state.get_plugin())
}

/// start_server
Expand Down Expand Up @@ -82,14 +75,14 @@ where
///
/// Closes server for new connections
pub fn close_server(&self) {
self.server_state.write().close();
self.server_state.close();
}

/// close_server
///
/// Closes all sessions
pub async fn stop_server(&self) -> Result<()> {
let sessions = self.server_state.write().close_all_sessions()?;
let sessions = self.server_state.close_all_sessions()?;
let stop_futures = sessions.iter().map(|s| s.wait_stopped()).collect_vec();

let (_, failures): (Vec<_>, Vec<_>) =
Expand All @@ -110,7 +103,7 @@ where
/// Sends a response to the appropriate client session based on the session ID in the response.
/// Returns a `Result` indicating success or failure if the session is not found.
pub fn send_response(&self, response: Response) -> Result<()> {
self.server_state.write().process_outgoing(response)
self.server_state.process_outgoing(response)
}
}

Expand All @@ -122,7 +115,7 @@ where
/// session is closed by the client, the loop is broken and the session destroyed.
async fn ws_handler<Response, Plugin>(
ws: WebSocketUpgrade,
State(server_state): State<Arc<RwLock<ServerState<Response, Plugin>>>>,
State(server_state): State<Arc<ServerState<Response, Plugin>>>,
) -> impl IntoResponse
where
Response: Send + Sync + 'static,
Expand All @@ -131,11 +124,11 @@ where
let timeout = std::time::Duration::from_secs(10);

ws.on_upgrade(async move |mut ws: WebSocket| {
if !server_state.read().is_accepting_connections() {
if !server_state.is_accepting_connections() {
return;
}

let (mut rx, session) = match server_state.write().create_session() {
let (mut rx, session) = match server_state.create_session() {
Err(err) => {
tracing::warn!("Failed to create session: {:?}", err);
return;
Expand Down Expand Up @@ -165,9 +158,7 @@ where
}
};

let maybe_message = server_state
.read()
.process_incoming(incoming_message, &session_id);
let maybe_message = server_state.process_incoming(incoming_message, &session_id);

match maybe_message {
Ok(result) => {
Expand Down Expand Up @@ -195,7 +186,7 @@ where

tracing::info!(%session_id, "Closing session");

if let Err(err) = server_state.write().close_session(session_id) {
if let Err(err) = server_state.close_session(session_id) {
tracing::warn!("Failed to close session: {:?}", err);
}
})
Expand Down
37 changes: 23 additions & 14 deletions libs/axum-fix-server/src/server_state.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::{
collections::{hash_map::Entry, HashMap},
marker::PhantomData,
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};

use eyre::{eyre, OptionExt, Result};
use itertools::Itertools;
use parking_lot::RwLock;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};

use crate::{messages::SessionId, server_plugin::ServerPlugin, session::Session};
Expand All @@ -31,9 +35,9 @@ impl Sessions {

pub struct ServerState<Response, Plugin> {
_phantom_response: PhantomData<Response>,
sessions: Sessions,
sessions: RwLock<Sessions>,
plugin: Plugin,
accept_connections: bool,
accept_connections: AtomicBool,
}

impl<Response, Plugin> ServerState<Response, Plugin>
Expand All @@ -45,18 +49,18 @@ where
Self {
_phantom_response: PhantomData::default(),

sessions: Sessions::new(),
sessions: RwLock::new(Sessions::new()),
plugin,
accept_connections: true,
accept_connections: AtomicBool::new(true),
}
}

pub fn is_accepting_connections(&self) -> bool {
self.accept_connections
self.accept_connections.load(Ordering::Relaxed)
}

pub fn close(&mut self) {
self.accept_connections = false;
pub fn close(&self) {
self.accept_connections.store(false, Ordering::Relaxed);
}

pub fn get_plugin(&self) -> &Plugin {
Expand All @@ -68,12 +72,13 @@ where
}

pub fn get_session_ids(&self) -> Vec<SessionId> {
self.sessions.sessions.keys().cloned().collect_vec()
self.sessions.read().sessions.keys().cloned().collect_vec()
}
pub fn create_session(&mut self) -> Result<(UnboundedReceiver<String>, Arc<Session>)> {
let session_id = self.sessions.gen_next_session_id();
pub fn create_session(&self) -> Result<(UnboundedReceiver<String>, Arc<Session>)> {
let mut sessions_write = self.sessions.write();
let session_id = sessions_write.gen_next_session_id();

match self.sessions.sessions.entry(session_id.clone()) {
match sessions_write.sessions.entry(session_id.clone()) {
Entry::Occupied(_) => Err(eyre!("Session {} already exists", session_id)),
Entry::Vacant(vacant_entry) => {
let (tx, rx) = unbounded_channel();
Expand All @@ -86,10 +91,12 @@ where
}
}

pub fn close_session(&mut self, session_id: SessionId) -> Result<Arc<Session>> {
pub fn close_session(&self, session_id: SessionId) -> Result<Arc<Session>> {
self.plugin.destroy_session(&session_id)?;

let session = self
.sessions
.write()
.sessions
.remove(&session_id)
.ok_or_eyre("Session no longer exists")?;
Expand All @@ -98,7 +105,7 @@ where
Ok(session)
}

pub fn close_all_sessions(&mut self) -> Result<Vec<Arc<Session>>> {
pub fn close_all_sessions(&self) -> Result<Vec<Arc<Session>>> {
let (good, bad): (Vec<_>, Vec<_>) = self
.get_session_ids()
.into_iter()
Expand Down Expand Up @@ -128,8 +135,10 @@ where
.map(|(sid, resp)| {
let session = self
.sessions
.read()
.sessions
.get(&sid)
.cloned()
.ok_or_else(|| eyre!("Session {} not found", sid))?;

session.send_response(resp)
Expand Down
14 changes: 2 additions & 12 deletions libs/axum-fix-server/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
use std::sync::Arc;

use axum::extract::ws::{Message, WebSocket};
use chrono::Utc;
use eyre::{eyre, Result};
use futures_util::TryFutureExt;
use itertools::Itertools;
use tokio::{
select,
sync::mpsc::{UnboundedReceiver, UnboundedSender},
time::sleep,
};
use tokio::sync::mpsc::UnboundedSender;
use tokio_util::sync::CancellationToken;

use crate::{messages::SessionId, server_plugin::ServerPlugin, server_state::ServerState};
use crate::messages::SessionId;

/// Session
///
Expand Down
14 changes: 2 additions & 12 deletions src/server/fix/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,7 @@ mod tests {
// Test ACK response
let ack_response = FixResponse::create_ack(&user_id, &session_id, 123);

assert_eq!(
axum_fix_server::plugins::rate_limit_plugin::WithRateLimitPlugin::get_user_id(
&ack_response
),
user_id
);
assert_eq!(WithRateLimitPlugin::get_user_id(&ack_response), user_id);
assert_eq!(
ack_response.get_rate_limit_key(),
RateLimitKey::User(user_id.0, user_id.1)
Expand All @@ -180,12 +175,7 @@ mod tests {
"Rate limit exceeded".to_string(),
);

assert_eq!(
axum_fix_server::plugins::rate_limit_plugin::WithRateLimitPlugin::get_user_id(
&nak_response
),
user_id
);
assert_eq!(WithRateLimitPlugin::get_user_id(&nak_response), user_id);
assert_eq!(
nak_response.get_rate_limit_key(),
RateLimitKey::User(user_id.0, user_id.1)
Expand Down
15 changes: 11 additions & 4 deletions src/server/fix/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::sync::Arc;

use axum_fix_server::server::Server as AxumFixServer;
use eyre::Result;
use symm_core::core::functional::{IntoObservableManyVTable, NotificationHandler};
use symm_core::core::{
functional::{IntoObservableManyVTable, NotificationHandler},
telemetry::{TraceableEvent, WithTracingContext},
};

use crate::server::{
fix::rate_limit_config::FixRateLimitConfig,
Expand All @@ -11,7 +14,7 @@ use crate::server::{
};

pub struct Server {
inner: AxumFixServer<ServerResponse, ServerPlugin>,
inner: AxumFixServer<TraceableEvent<ServerResponse>, ServerPlugin>,
}

impl Server {
Expand All @@ -37,13 +40,17 @@ impl Server {
impl IntoObservableManyVTable<Arc<ServerEvent>> for Server {
fn add_observer(&mut self, observer: Box<dyn NotificationHandler<Arc<ServerEvent>>>) {
self.inner
.with_plugin_mut(|plugin| plugin.add_observer(observer))
.with_plugin(|plugin| plugin.add_observer(observer))
}
}

impl ServerInterface for Server {
fn respond_with(&mut self, response: ServerResponse) {
if let Err(err) = self.inner.send_response(response) {
let mut traceable_response = TraceableEvent::new(response);
traceable_response.inject_baggage();
traceable_response.inject_current_context();

if let Err(err) = self.inner.send_response(traceable_response) {
tracing::warn!("Failed to respond with: {:?}", err);
}
}
Expand Down
Loading