diff --git a/src/client/async.rs b/src/client/async.rs index 4876f0b5..09f6eb9c 100644 --- a/src/client/async.rs +++ b/src/client/async.rs @@ -7,7 +7,7 @@ use log::debug; use time::OffsetDateTime; use time_tz::Tz; -use crate::connection::common::StartupMessageCallback; +use crate::connection::common::{ConnectionOptions, StartupMessageCallback}; use crate::connection::{r#async::AsyncConnection, ConnectionMetadata}; use crate::messages::{OutgoingMessages, RequestMessage}; use crate::transport::{ @@ -117,7 +117,37 @@ impl Client { /// } /// ``` pub async fn connect_with_callback(address: &str, client_id: i32, startup_callback: Option) -> Result { - let connection = AsyncConnection::connect_with_callback(address, client_id, startup_callback).await?; + Self::connect_with_options(address, client_id, startup_callback.into()).await + } + + /// Establishes async connection to TWS or Gateway with custom options + /// + /// This is similar to [`connect`](Self::connect), but allows you to configure + /// connection options like `TCP_NODELAY` and startup callbacks via + /// [`ConnectionOptions`]. + /// + /// # Arguments + /// * `address` - address of server. e.g. 127.0.0.1:4002 + /// * `client_id` - id of client. e.g. 100 + /// * `options` - connection options + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::{Client, ConnectionOptions}; + /// + /// #[tokio::main] + /// async fn main() { + /// let options = ConnectionOptions::default() + /// .tcp_no_delay(true); + /// + /// let client = Client::connect_with_options("127.0.0.1:4002", 100, options) + /// .await + /// .expect("connection failed"); + /// } + /// ``` + pub async fn connect_with_options(address: &str, client_id: i32, options: ConnectionOptions) -> Result { + let connection = AsyncConnection::connect_with_options(address, client_id, options).await?; let connection_metadata = connection.connection_metadata(); let message_bus = Arc::new(AsyncTcpMessageBus::new(connection)?); diff --git a/src/client/sync.rs b/src/client/sync.rs index c5062165..36dbc2d0 100644 --- a/src/client/sync.rs +++ b/src/client/sync.rs @@ -5,7 +5,6 @@ //! subscriptions, and maintains the connection state. use std::fmt::Debug; -use std::net::TcpStream; use std::sync::Arc; use std::time::Duration; @@ -15,7 +14,7 @@ use time_tz::Tz; use crate::accounts::types::{AccountGroup, AccountId, ContractId, ModelCode}; use crate::accounts::{AccountSummaryResult, AccountUpdate, AccountUpdateMulti, FamilyCode, PnL, PnLSingle, PositionUpdate, PositionUpdateMulti}; -use crate::connection::common::StartupMessageCallback; +use crate::connection::common::{ConnectionOptions, StartupMessageCallback}; use crate::connection::{sync::Connection, ConnectionMetadata}; use crate::contracts::{Contract, OptionComputation, SecurityType}; use crate::errors::Error; @@ -28,7 +27,7 @@ use crate::news::NewsArticle; use crate::orders::{CancelOrder, Executions, ExerciseOptions, Order, OrderBuilder, OrderUpdate, Orders, PlaceOrder}; use crate::scanner::ScannerData; use crate::subscriptions::sync::Subscription; -use crate::transport::{InternalSubscription, MessageBus, TcpMessageBus, TcpSocket}; +use crate::transport::{InternalSubscription, MessageBus, TcpMessageBus}; use crate::wsh::AutoFill; use crate::{accounts, contracts, display_groups, market_data, news, orders, scanner, wsh}; @@ -114,10 +113,34 @@ impl Client { /// println!("Received {} startup orders", orders.lock().unwrap().len()); /// ``` pub fn connect_with_callback(address: &str, client_id: i32, startup_callback: Option) -> Result { - let stream = TcpStream::connect(address)?; - let socket = TcpSocket::new(stream, address)?; + Self::connect_with_options(address, client_id, startup_callback.into()) + } - let connection = Connection::connect_with_callback(socket, client_id, startup_callback)?; + /// Establishes connection to TWS or Gateway with custom options + /// + /// This is similar to [`connect`](Self::connect), but allows you to configure + /// connection options like `TCP_NODELAY` and startup callbacks via + /// [`ConnectionOptions`]. + /// + /// # Arguments + /// * `address` - address of server. e.g. 127.0.0.1:4002 + /// * `client_id` - id of client. e.g. 100 + /// * `options` - connection options + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::client::blocking::Client; + /// use ibapi::ConnectionOptions; + /// + /// let options = ConnectionOptions::default() + /// .tcp_no_delay(true); + /// + /// let client = Client::connect_with_options("127.0.0.1:4002", 100, options) + /// .expect("connection failed"); + /// ``` + pub fn connect_with_options(address: &str, client_id: i32, options: ConnectionOptions) -> Result { + let connection = Connection::connect_with_options(address, client_id, options)?; let connection_metadata = connection.connection_metadata(); let message_bus = Arc::new(TcpMessageBus::new(connection)?); diff --git a/src/connection.rs b/src/connection.rs index ade98277..aab51ad0 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -5,7 +5,7 @@ use time_tz::Tz; pub mod common; -// Re-export StartupMessageCallback for lib.rs to re-export publicly +pub use common::ConnectionOptions; pub use common::StartupMessageCallback; /// Metadata about the connection to TWS diff --git a/src/connection/async.rs b/src/connection/async.rs index 67dd9ab8..9d51ed03 100644 --- a/src/connection/async.rs +++ b/src/connection/async.rs @@ -6,7 +6,7 @@ use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::time::sleep; -use super::common::{parse_connection_time, AccountInfo, ConnectionHandler, ConnectionProtocol, StartupMessageCallback}; +use super::common::{parse_connection_time, AccountInfo, ConnectionHandler, ConnectionOptions, ConnectionProtocol, StartupMessageCallback}; use super::ConnectionMetadata; use crate::errors::Error; use crate::messages::{RequestMessage, ResponseMessage}; @@ -25,6 +25,7 @@ pub struct AsyncConnection { pub(crate) recorder: MessageRecorder, pub(crate) connection_handler: ConnectionHandler, pub(crate) connection_url: String, + pub(crate) options: ConnectionOptions, } impl AsyncConnection { @@ -39,7 +40,15 @@ impl AsyncConnection { /// The callback will be invoked for any messages received during connection /// setup that are not part of the normal handshake (e.g., OpenOrder, OrderStatus). pub async fn connect_with_callback(address: &str, client_id: i32, startup_callback: Option) -> Result { - let socket = TcpStream::connect(address).await?; + Self::connect_with_options(address, client_id, startup_callback.into()).await + } + + /// Create a new async connection with custom options. + /// + /// Applies settings from [`ConnectionOptions`] (e.g. `TCP_NODELAY`, startup callback) + /// before performing the TWS handshake. + pub async fn connect_with_options(address: &str, client_id: i32, options: ConnectionOptions) -> Result { + let socket = Self::connect_socket(address, &options).await?; let connection = Self { client_id, @@ -51,13 +60,21 @@ impl AsyncConnection { recorder: MessageRecorder::from_env(), connection_handler: ConnectionHandler::default(), connection_url: address.to_string(), + options, }; - connection.establish_connection(startup_callback.as_ref()).await?; + let cb_ref = connection.options.startup_callback.as_deref(); + connection.establish_connection(cb_ref).await?; Ok(connection) } + async fn connect_socket(address: &str, options: &ConnectionOptions) -> Result { + let socket = TcpStream::connect(address).await?; + socket.set_nodelay(options.tcp_no_delay)?; + Ok(socket) + } + /// Get a copy of the connection metadata pub fn connection_metadata(&self) -> ConnectionMetadata { // For now, we'll use blocking lock since this is called during initialization @@ -88,7 +105,7 @@ impl AsyncConnection { sleep(next_delay).await; - match TcpStream::connect(&self.connection_url).await { + match Self::connect_socket(&self.connection_url, &self.options).await { Ok(new_socket) => { info!("reconnected !!!"); @@ -112,7 +129,7 @@ impl AsyncConnection { } /// Establish connection to TWS - pub(crate) async fn establish_connection(&self, startup_callback: Option<&StartupMessageCallback>) -> Result<(), Error> { + pub(crate) async fn establish_connection(&self, startup_callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>) -> Result<(), Error> { self.handshake().await?; self.start_api().await?; self.receive_account_info(startup_callback).await?; @@ -219,7 +236,7 @@ impl AsyncConnection { } // Fetches next order id and managed accounts. - pub(crate) async fn receive_account_info(&self, startup_callback: Option<&StartupMessageCallback>) -> Result<(), Error> { + pub(crate) async fn receive_account_info(&self, startup_callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>) -> Result<(), Error> { let mut account_info = AccountInfo::default(); let mut attempts = 0; diff --git a/src/connection/common.rs b/src/connection/common.rs index 77bc02cc..f91ea5da 100644 --- a/src/connection/common.rs +++ b/src/connection/common.rs @@ -1,5 +1,8 @@ //! Common connection logic shared between sync and async implementations +use std::fmt; +use std::sync::Arc; + use log::{debug, error, warn}; use time::macros::format_description; use time::OffsetDateTime; @@ -17,6 +20,64 @@ use crate::server_versions; /// instead of discarding them. pub type StartupMessageCallback = Box; +/// Options for configuring a connection to TWS or IB Gateway. +/// +/// Use the builder methods to configure options, then pass to +/// [`Client::connect_with_options`](crate::Client::connect_with_options). +/// +/// # Examples +/// +/// ``` +/// use ibapi::ConnectionOptions; +/// +/// let options = ConnectionOptions::default() +/// .tcp_no_delay(true); +/// ``` +#[derive(Clone, Default)] +pub struct ConnectionOptions { + pub(crate) tcp_no_delay: bool, + pub(crate) startup_callback: Option>, +} + +impl ConnectionOptions { + /// Enable or disable `TCP_NODELAY` on the connection socket. + /// + /// When enabled, disables Nagle's algorithm for lower latency. + /// Default: `false`. + pub fn tcp_no_delay(mut self, enabled: bool) -> Self { + self.tcp_no_delay = enabled; + self + } + + /// Set a callback for unsolicited messages during connection setup. + /// + /// When TWS sends messages like `OpenOrder` or `OrderStatus` during the + /// connection handshake, this callback processes them instead of discarding. + pub fn startup_callback(mut self, callback: impl Fn(ResponseMessage) + Send + Sync + 'static) -> Self { + self.startup_callback = Some(Arc::new(callback)); + self + } +} + +impl From> for ConnectionOptions { + fn from(callback: Option) -> Self { + let mut opts = Self::default(); + if let Some(cb) = callback { + opts.startup_callback = Some(Arc::from(cb)); + } + opts + } +} + +impl fmt::Debug for ConnectionOptions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConnectionOptions") + .field("tcp_no_delay", &self.tcp_no_delay) + .field("startup_callback", &self.startup_callback.is_some()) + .finish() + } +} + /// Data exchanged during the connection handshake #[derive(Debug, Clone)] #[allow(dead_code)] @@ -44,7 +105,11 @@ pub trait ConnectionProtocol { /// /// If a callback is provided, unsolicited messages (like OpenOrder, OrderStatus) /// will be passed to it instead of being discarded. - fn parse_account_info(&self, message: &mut ResponseMessage, callback: Option<&StartupMessageCallback>) -> Result; + fn parse_account_info( + &self, + message: &mut ResponseMessage, + callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>, + ) -> Result; } /// Account information received during connection establishment @@ -109,7 +174,11 @@ impl ConnectionProtocol for ConnectionHandler { message } - fn parse_account_info(&self, message: &mut ResponseMessage, callback: Option<&StartupMessageCallback>) -> Result { + fn parse_account_info( + &self, + message: &mut ResponseMessage, + callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>, + ) -> Result { let mut info = AccountInfo::default(); match message.message_type() { @@ -442,4 +511,33 @@ mod tests { // server_time will contain replacement characters but parsing succeeds assert!(handshake_data.server_time.contains("20251205")); } + + #[test] + fn test_connection_options_default() { + let opts = ConnectionOptions::default(); + assert_eq!(opts.tcp_no_delay, false); + assert!(opts.startup_callback.is_none()); + } + + #[test] + fn test_connection_options_builder() { + let opts = ConnectionOptions::default().tcp_no_delay(true).startup_callback(|_msg| {}); + assert_eq!(opts.tcp_no_delay, true); + assert!(opts.startup_callback.is_some()); + } + + #[test] + fn test_connection_options_clone() { + let opts = ConnectionOptions::default().tcp_no_delay(true); + let cloned = opts.clone(); + assert_eq!(cloned.tcp_no_delay, true); + } + + #[test] + fn test_connection_options_debug() { + let opts = ConnectionOptions::default().tcp_no_delay(true); + let debug_str = format!("{:?}", opts); + assert!(debug_str.contains("tcp_no_delay: true")); + assert!(debug_str.contains("startup_callback: false")); + } } diff --git a/src/connection/sync.rs b/src/connection/sync.rs index 09764807..a793327f 100644 --- a/src/connection/sync.rs +++ b/src/connection/sync.rs @@ -4,7 +4,7 @@ use std::sync::Mutex; use log::{debug, info}; -use super::common::{parse_connection_time, AccountInfo, ConnectionHandler, ConnectionProtocol, StartupMessageCallback}; +use super::common::{parse_connection_time, AccountInfo, ConnectionHandler, ConnectionOptions, ConnectionProtocol, StartupMessageCallback}; use super::ConnectionMetadata; use crate::errors::Error; use crate::messages::{RequestMessage, ResponseMessage}; @@ -12,6 +12,7 @@ use crate::trace; use crate::transport::common::{FibonacciBackoff, MAX_RECONNECT_ATTEMPTS}; use crate::transport::recorder::MessageRecorder; use crate::transport::sync::Stream; +use crate::transport::sync::TcpSocket; type Response = Result; @@ -26,18 +27,34 @@ pub struct Connection { pub(crate) connection_handler: ConnectionHandler, } +impl Connection { + /// Create a connection with custom options. + /// + /// Applies settings from [`ConnectionOptions`] (e.g. `TCP_NODELAY`, startup callback) + /// before performing the TWS handshake. + pub fn connect_with_options(address: &str, client_id: i32, options: ConnectionOptions) -> Result { + let socket = TcpSocket::connect(address, options.tcp_no_delay)?; + Self::init(socket, client_id, options.startup_callback.as_deref()) + } +} + impl Connection { /// Create a new connection #[allow(dead_code)] pub fn connect(socket: S, client_id: i32) -> Result { - Self::connect_with_callback(socket, client_id, None) + Self::init(socket, client_id, None) } /// Create a new connection with a callback for unsolicited messages /// /// The callback will be invoked for any messages received during connection /// setup that are not part of the normal handshake (e.g., OpenOrder, OrderStatus). + #[allow(dead_code)] pub fn connect_with_callback(socket: S, client_id: i32, startup_callback: Option) -> Result { + Self::init(socket, client_id, startup_callback.as_deref()) + } + + fn init(socket: S, client_id: i32, startup_callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>) -> Result { let connection = Self { client_id, socket, @@ -50,7 +67,7 @@ impl Connection { connection_handler: ConnectionHandler::default(), }; - connection.establish_connection(startup_callback.as_ref())?; + connection.establish_connection(startup_callback)?; Ok(connection) } @@ -95,7 +112,7 @@ impl Connection { } /// Establish connection to TWS - pub(crate) fn establish_connection(&self, startup_callback: Option<&StartupMessageCallback>) -> Result<(), Error> { + pub(crate) fn establish_connection(&self, startup_callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>) -> Result<(), Error> { self.handshake()?; self.start_api()?; self.receive_account_info(startup_callback)?; @@ -175,7 +192,7 @@ impl Connection { } // Fetches next order id and managed accounts. - pub(crate) fn receive_account_info(&self, startup_callback: Option<&StartupMessageCallback>) -> Result<(), Error> { + pub(crate) fn receive_account_info(&self, startup_callback: Option<&(dyn Fn(ResponseMessage) + Send + Sync)>) -> Result<(), Error> { let mut account_info = AccountInfo::default(); let mut attempts = 0; diff --git a/src/lib.rs b/src/lib.rs index 66c324e8..ef272506 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,6 +86,7 @@ pub(crate) mod connection; /// println!("Received {} startup orders", orders.lock().unwrap().len()); /// } /// ``` +pub use connection::ConnectionOptions; pub use connection::StartupMessageCallback; /// Common utilities shared across modules diff --git a/src/prelude.rs b/src/prelude.rs index f5bd4044..9668f3a5 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -20,6 +20,7 @@ // Core client pub use crate::Client; +pub use crate::ConnectionOptions; pub use crate::Error; // Contract types diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 32f4e391..121fc91e 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -267,9 +267,6 @@ impl SubscriptionBuilder { #[cfg(feature = "sync")] pub use sync::TcpMessageBus; -#[cfg(feature = "sync")] -pub(crate) use sync::TcpSocket; - // Async exports (placeholder for now) #[cfg(feature = "async")] pub use r#async::{AsyncInternalSubscription, AsyncMessageBus}; diff --git a/src/transport/sync.rs b/src/transport/sync.rs index d0f26ec8..6b0d022b 100644 --- a/src/transport/sync.rs +++ b/src/transport/sync.rs @@ -699,17 +699,25 @@ pub(crate) struct TcpSocket { reader: Mutex, writer: Mutex, connection_url: String, + tcp_no_delay: bool, } impl TcpSocket { - pub fn new(stream: TcpStream, connection_url: &str) -> Result { + pub fn connect(address: &str, tcp_no_delay: bool) -> Result { + let stream = TcpStream::connect(address)?; + Self::new(stream, address, tcp_no_delay) + } + + pub fn new(stream: TcpStream, connection_url: &str, tcp_no_delay: bool) -> Result { let writer = stream.try_clone()?; stream.set_read_timeout(Some(TWS_READ_TIMEOUT))?; + stream.set_nodelay(tcp_no_delay)?; Ok(Self { reader: Mutex::new(stream), writer: Mutex::new(writer), connection_url: connection_url.to_string(), + tcp_no_delay, }) } } @@ -719,6 +727,7 @@ impl Reconnect for TcpSocket { match TcpStream::connect(&self.connection_url) { Ok(stream) => { stream.set_read_timeout(Some(TWS_READ_TIMEOUT))?; + stream.set_nodelay(self.tcp_no_delay)?; let mut reader = self.reader.lock()?; *reader = stream.try_clone()?;