From ca57308edc02d39dd4c4d29a41376e9535870bb5 Mon Sep 17 00:00:00 2001 From: AdAstraPerAsperaMX Date: Tue, 6 Jan 2026 19:24:21 -0600 Subject: [PATCH 01/18] feat: Add update_display_group method to clients Implement new method to update contracts in TWS display groups for both synchronous and asynchronous clients. The method allows changing the displayed contract by providing a contract identifier. --- src/client/async.rs | 34 +++++++++++++++++++++++ src/client/sync.rs | 31 +++++++++++++++++++++ src/display_groups/async.rs | 19 +++++++++++++ src/display_groups/common/encoders.rs | 39 +++++++++++++++++++++++++++ src/display_groups/sync.rs | 19 +++++++++++++ src/subscriptions/async.rs | 5 ++++ src/subscriptions/sync.rs | 5 ++++ 7 files changed, 152 insertions(+) diff --git a/src/client/async.rs b/src/client/async.rs index 529063bf..a634283e 100644 --- a/src/client/async.rs +++ b/src/client/async.rs @@ -604,6 +604,40 @@ impl Client { display_groups::r#async::subscribe_to_group_events(self, group_id).await } + /// Updates the contract displayed in a TWS display group. + /// + /// This function changes the contract shown in the specified display group within TWS. + /// You must first subscribe to the group using [`subscribe_to_group_events`](Self::subscribe_to_group_events) + /// before calling this function. + /// + /// # Arguments + /// * `request_id` - The request ID from the subscription (use `subscription.request_id()`) + /// * `contract_info` - Contract to display: + /// - `"contractID@exchange"` for individual contracts (e.g., "265598@SMART") + /// - `"none"` for empty selection + /// - `"combo"` for combination contracts + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::Client; + /// + /// #[tokio::main] + /// async fn main() { + /// let client = Client::connect("127.0.0.1:7497", 100).await.expect("connection failed"); + /// + /// // First subscribe to the display group + /// let subscription = client.subscribe_to_group_events(1).await.expect("subscription failed"); + /// let request_id = subscription.request_id().expect("no request ID"); + /// + /// // Update the display group to show AAPL + /// client.update_display_group(request_id, "265598@SMART").await.expect("update failed"); + /// } + /// ``` + pub async fn update_display_group(&self, request_id: i32, contract_info: &str) -> Result<(), Error> { + display_groups::r#async::update_display_group(self, request_id, contract_info).await + } + // === Market Data === /// Creates a market data subscription builder with a fluent interface. diff --git a/src/client/sync.rs b/src/client/sync.rs index f01cecee..728690f8 100644 --- a/src/client/sync.rs +++ b/src/client/sync.rs @@ -510,6 +510,37 @@ impl Client { display_groups::sync::subscribe_to_group_events(self, group_id) } + /// Updates the contract displayed in a TWS display group. + /// + /// This function changes the contract shown in the specified display group within TWS. + /// You must first subscribe to the group using [`subscribe_to_group_events`](Self::subscribe_to_group_events) + /// before calling this function. + /// + /// # Arguments + /// * `request_id` - The request ID from the subscription (use `subscription.request_id()`) + /// * `contract_info` - Contract to display: + /// - `"contractID@exchange"` for individual contracts (e.g., "265598@SMART") + /// - `"none"` for empty selection + /// - `"combo"` for combination contracts + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::client::blocking::Client; + /// + /// let client = Client::connect("127.0.0.1:7497", 100).expect("connection failed"); + /// + /// // First subscribe to the display group + /// let subscription = client.subscribe_to_group_events(1).expect("subscription failed"); + /// let request_id = subscription.request_id().expect("no request ID"); + /// + /// // Update the display group to show AAPL + /// client.update_display_group(request_id, "265598@SMART").expect("update failed"); + /// ``` + pub fn update_display_group(&self, request_id: i32, contract_info: &str) -> Result<(), Error> { + display_groups::sync::update_display_group(self, request_id, contract_info) + } + // === Contracts === /// Requests contract information. diff --git a/src/display_groups/async.rs b/src/display_groups/async.rs index 93d0c49e..f6ebd9e6 100644 --- a/src/display_groups/async.rs +++ b/src/display_groups/async.rs @@ -22,6 +22,25 @@ pub async fn subscribe_to_group_events(client: &Client, group_id: i32) -> Result builder.send::(request).await } +/// Updates the contract displayed in a TWS display group. +/// +/// This function changes the contract shown in the specified display group within TWS. +/// You must first subscribe to the group using [`subscribe_to_group_events`] before +/// calling this function. The update will trigger a `DisplayGroupUpdated` callback +/// on the existing subscription. +/// +/// # Arguments +/// * `client` - The connected client +/// * `request_id` - The request ID from the subscription (use `subscription.request_id()`) +/// * `contract_info` - Contract to display: +/// - `"contractID@exchange"` for individual contracts (e.g., "265598@SMART") +/// - `"none"` for empty selection +/// - `"combo"` for combination contracts +pub async fn update_display_group(client: &Client, request_id: i32, contract_info: &str) -> Result<(), Error> { + let request = encoders::encode_update_display_group(request_id, contract_info)?; + client.send_message(request).await +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/display_groups/common/encoders.rs b/src/display_groups/common/encoders.rs index 5a39c9ed..bb098ac5 100644 --- a/src/display_groups/common/encoders.rs +++ b/src/display_groups/common/encoders.rs @@ -24,6 +24,21 @@ pub(crate) fn encode_unsubscribe_from_group_events(request_id: i32) -> Result Result { + let mut message = RequestMessage::new(); + message.push_field(&OutgoingMessages::UpdateDisplayGroup); + message.push_field(&VERSION); + message.push_field(&request_id); + message.push_field(&contract_info); + Ok(message) +} + #[cfg(test)] mod tests { use super::*; @@ -52,4 +67,28 @@ mod tests { assert_eq!(message[1], "1"); // version assert_eq!(message[2], request_id.to_field()); } + + #[test] + fn test_encode_update_display_group() { + let request_id = 9000; + let contract_info = "265598@SMART"; + + let message = encode_update_display_group(request_id, contract_info).expect("encoding failed"); + + assert_eq!(message[0], OutgoingMessages::UpdateDisplayGroup.to_field()); + assert_eq!(message[1], "1"); // version + assert_eq!(message[2], request_id.to_field()); + assert_eq!(message[3], contract_info); + } + + #[test] + fn test_encode_update_display_group_none() { + let request_id = 9000; + let contract_info = "none"; + + let message = encode_update_display_group(request_id, contract_info).expect("encoding failed"); + + assert_eq!(message[0], OutgoingMessages::UpdateDisplayGroup.to_field()); + assert_eq!(message[3], "none"); + } } diff --git a/src/display_groups/sync.rs b/src/display_groups/sync.rs index a99fb352..80c38e1f 100644 --- a/src/display_groups/sync.rs +++ b/src/display_groups/sync.rs @@ -21,3 +21,22 @@ pub fn subscribe_to_group_events(client: &Client, group_id: i32) -> Result Result<(), Error> { + let request = encoders::encode_update_display_group(request_id, contract_info)?; + client.send_message(request) +} diff --git a/src/subscriptions/async.rs b/src/subscriptions/async.rs index 0fca1aa1..4e10a025 100644 --- a/src/subscriptions/async.rs +++ b/src/subscriptions/async.rs @@ -249,6 +249,11 @@ impl Subscription { SubscriptionInner::PreDecoded { receiver } => receiver.recv().await, } } + + /// Get the request ID associated with this subscription + pub fn request_id(&self) -> Option { + self.request_id + } } impl Subscription { diff --git a/src/subscriptions/sync.rs b/src/subscriptions/sync.rs index 283ca322..9cf8806f 100644 --- a/src/subscriptions/sync.rs +++ b/src/subscriptions/sync.rs @@ -100,6 +100,11 @@ impl> Subscription { } } + /// Returns the request ID associated with this subscription. + pub fn request_id(&self) -> Option { + self.request_id + } + /// Returns the next available value, blocking if necessary until a value becomes available. /// /// # Examples From caa7a2d68079a7047632370465b21f9ab89daaa0 Mon Sep 17 00:00:00 2001 From: kingyond Date: Thu, 8 Jan 2026 07:57:23 +0800 Subject: [PATCH 02/18] Add time zone map to support China standard time (#364) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add time zone map * fix: correct China timezone mapping to Asia/Shanghai - Create shared timezone utility in src/common/timezone.rs - Map Chinese timezone names to Asia/Shanghai (UTC+8) instead of CST (UTC-6) - Handle UTF-8 Chinese names: 中国标准时间, 北京时间 - Handle Windows name: China Standard Time - Handle GB2312 mojibake by detecting U+FFFD replacement characters - Update both connection and market_data modules to use shared utility - Add tests for timezone mapping --------- Co-authored-by: kingyond <900483@qq.com> Co-authored-by: Wil Boayue --- src/common/mod.rs | 1 + src/common/timezone.rs | 83 +++++++++++++++++++ src/connection/common.rs | 42 +++++++++- src/market_data/historical/common/decoders.rs | 5 +- 4 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 src/common/timezone.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 3e25ceac..d67bb0b9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,6 +3,7 @@ pub mod error_helpers; pub mod request_helpers; pub mod retry; +pub mod timezone; #[cfg(test)] pub mod test_utils; diff --git a/src/common/timezone.rs b/src/common/timezone.rs new file mode 100644 index 00000000..4879286e --- /dev/null +++ b/src/common/timezone.rs @@ -0,0 +1,83 @@ +//! Timezone utilities for handling IB Gateway timezone names + +use time_tz::{timezones, Tz}; + +/// Find timezone by name, handling non-standard names from IB Gateway. +/// +/// IB Gateway may send timezone names in various formats: +/// - IANA names: "America/New_York", "Asia/Shanghai" +/// - Abbreviations: "PST", "EST" +/// - Windows names: "China Standard Time" +/// - Localized names: "中国标准时间" (Chinese) +/// - Mojibake from encoding issues (GB2312 decoded as UTF-8) +pub fn find_timezone(name: &str) -> Vec<&'static Tz> { + let mapped = map_timezone_name(name); + timezones::find_by_name(mapped) +} + +/// Map non-standard timezone names to IANA identifiers. +fn map_timezone_name(name: &str) -> &str { + // UTF-8 Chinese timezone names + if name == "中国标准时间" || name == "北京时间" { + return "Asia/Shanghai"; + } + + // Windows English timezone name + if name == "China Standard Time" { + return "Asia/Shanghai"; + } + + // GB2312/GBK encoded strings decoded as UTF-8 lossy contain U+FFFD. + // In IB Gateway context, this indicates a Chinese installation. + if name.contains('\u{FFFD}') { + return "Asia/Shanghai"; + } + + name +} + +#[cfg(test)] +mod tests { + use super::*; + use time_tz::TimeZone; + + #[test] + fn test_find_timezone_standard() { + assert!(!find_timezone("PST").is_empty()); + assert!(!find_timezone("America/New_York").is_empty()); + } + + #[test] + fn test_find_timezone_china_utf8() { + let zones = find_timezone("中国标准时间"); + assert!(!zones.is_empty()); + assert_eq!(zones[0].name(), "Asia/Shanghai"); + + let zones = find_timezone("北京时间"); + assert!(!zones.is_empty()); + assert_eq!(zones[0].name(), "Asia/Shanghai"); + } + + #[test] + fn test_find_timezone_china_english() { + let zones = find_timezone("China Standard Time"); + assert!(!zones.is_empty()); + assert_eq!(zones[0].name(), "Asia/Shanghai"); + } + + #[test] + fn test_find_timezone_mojibake() { + // Simulate GB2312 decoded as UTF-8 lossy (contains replacement characters) + let mojibake = "test\u{FFFD}\u{FFFD}zone"; + let zones = find_timezone(mojibake); + assert!(!zones.is_empty()); + assert_eq!(zones[0].name(), "Asia/Shanghai"); + } + + #[test] + fn test_find_timezone_passthrough() { + // Unknown timezone names pass through unchanged + let zones = find_timezone("Unknown/Timezone"); + assert!(zones.is_empty()); + } +} diff --git a/src/connection/common.rs b/src/connection/common.rs index 9d66f8fa..3cc38385 100644 --- a/src/connection/common.rs +++ b/src/connection/common.rs @@ -3,8 +3,9 @@ use log::{debug, error, warn}; use time::macros::format_description; use time::OffsetDateTime; -use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, Tz}; +use time_tz::{OffsetResult, PrimitiveDateTimeExt, Tz}; +use crate::common::timezone::find_timezone; use crate::errors::Error; use crate::messages::{encode_length, IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; use crate::server_versions; @@ -152,10 +153,12 @@ pub fn parse_connection_time(connection_time: &str) -> (Option, return (None, None); } - let zones = timezones::find_by_name(parts[2]); + // Combine timezone parts if more than 3 parts (e.g., "China Standard Time") + let tz_name = if parts.len() > 3 { parts[2..].join(" ") } else { parts[2].to_string() }; + let zones = find_timezone(&tz_name); if zones.is_empty() { - error!("Time zone not found for {}", parts[2]); + error!("Time zone not found for {}", tz_name); return (None, None); } @@ -185,7 +188,7 @@ mod tests { use super::*; use std::sync::{Arc, Mutex}; use time::macros::datetime; - use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt}; + use time_tz::{timezones, OffsetResult, PrimitiveDateTimeExt, TimeZone}; #[test] fn test_parse_account_info_next_valid_id() { @@ -348,6 +351,37 @@ mod tests { } } + #[test] + fn test_parse_connection_time_china_standard_time() { + let example = "20230405 22:20:39 China Standard Time"; + let (connection_time, timezone) = parse_connection_time(example); + + assert!(connection_time.is_some()); + assert!(timezone.is_some()); + assert_eq!(timezone.unwrap().name(), "Asia/Shanghai"); + } + + #[test] + fn test_parse_connection_time_chinese_utf8() { + let example = "20230405 22:20:39 中国标准时间"; + let (connection_time, timezone) = parse_connection_time(example); + + assert!(connection_time.is_some()); + assert!(timezone.is_some()); + assert_eq!(timezone.unwrap().name(), "Asia/Shanghai"); + } + + #[test] + fn test_parse_connection_time_mojibake() { + // Simulate GB2312 timezone decoded as UTF-8 lossy + let example = "20230405 22:20:39 \u{FFFD}\u{FFFD}\u{FFFD}"; + let (connection_time, timezone) = parse_connection_time(example); + + assert!(connection_time.is_some()); + assert!(timezone.is_some()); + assert_eq!(timezone.unwrap().name(), "Asia/Shanghai"); + } + #[test] fn test_connection_handler_handshake() { let handler = ConnectionHandler::default(); diff --git a/src/market_data/historical/common/decoders.rs b/src/market_data/historical/common/decoders.rs index 7286cf9b..9c319ff1 100644 --- a/src/market_data/historical/common/decoders.rs +++ b/src/market_data/historical/common/decoders.rs @@ -1,7 +1,8 @@ use time::macros::{format_description, time}; use time::{Date, OffsetDateTime, PrimitiveDateTime}; -use time_tz::{timezones, OffsetDateTimeExt, PrimitiveDateTimeExt, Tz}; +use time_tz::{OffsetDateTimeExt, PrimitiveDateTimeExt, Tz}; +use crate::common::timezone::find_timezone; use crate::messages::ResponseMessage; use crate::{server_versions, Error}; @@ -221,7 +222,7 @@ pub(crate) fn decode_histogram_data(message: &mut ResponseMessage) -> Result &Tz { - let zones = timezones::find_by_name(name); + let zones = find_timezone(name); if zones.is_empty() { panic!("timezone not found for: {name}") } From eef039ca89cac7c1415915629ed106746fed2788 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Wed, 7 Jan 2026 16:22:55 -0800 Subject: [PATCH 03/18] feat: add Notice helper methods to distinguish message types (#366) Add helper methods to Notice struct: - is_cancellation() - code 202 (order cancelled confirmation) - is_warning() - codes 2100-2169 - is_system_message() - codes 1100, 1101, 1102, 1300 (connectivity) - is_informational() - true if any of the above - is_error() - true if NOT informational Fixes #365 --- src/messages.rs | 53 +++++++++++++++ src/messages/tests.rs | 140 +++++++++++++++++++++++++++++++++++++++ src/transport/routing.rs | 5 +- 3 files changed, 194 insertions(+), 4 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index b7c6e713..0d460407 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1074,6 +1074,19 @@ pub struct Notice { pub message: String, } +/// Error code indicating an order was cancelled (confirmation, not an error). +pub const ORDER_CANCELLED_CODE: i32 = 202; + +/// Range of error codes that are considered warnings (2100-2169). +pub const WARNING_CODE_RANGE: std::ops::RangeInclusive = 2100..=2169; + +/// System message codes indicating connectivity status. +/// - 1100: Connectivity lost +/// - 1101: Connectivity restored, market data lost (resubscribe needed) +/// - 1102: Connectivity restored, market data maintained +/// - 1300: Socket port reset during active connection +pub const SYSTEM_MESSAGE_CODES: [i32; 4] = [1100, 1101, 1102, 1300]; + impl Notice { #[allow(private_interfaces)] /// Construct a notice from a response message. @@ -1082,6 +1095,46 @@ impl Notice { let message = message.peek_string(MESSAGE_INDEX); Notice { code, message } } + + /// Returns `true` if this notice indicates an order was cancelled (code 202). + /// + /// Code 202 is sent by TWS to confirm an order cancellation. This is an + /// informational message, not an error. + pub fn is_cancellation(&self) -> bool { + self.code == ORDER_CANCELLED_CODE + } + + /// Returns `true` if this is a warning message (codes 2100-2169). + pub fn is_warning(&self) -> bool { + WARNING_CODE_RANGE.contains(&self.code) + } + + /// Returns `true` if this is a system/connectivity message (codes 1100-1102, 1300). + /// + /// System messages indicate connectivity status changes: + /// - 1100: Connectivity between IB and TWS lost + /// - 1101: Connectivity restored, market data lost (resubscribe needed) + /// - 1102: Connectivity restored, market data maintained + /// - 1300: Socket port reset during active connection + pub fn is_system_message(&self) -> bool { + SYSTEM_MESSAGE_CODES.contains(&self.code) + } + + /// Returns `true` if this is an informational notice (not an error). + /// + /// Informational notices include cancellation confirmations, warnings, + /// and system/connectivity messages. + pub fn is_informational(&self) -> bool { + self.is_cancellation() || self.is_warning() || self.is_system_message() + } + + /// Returns `true` if this is an error requiring attention. + /// + /// Returns `false` for informational messages like cancellation confirmations, + /// warnings, and system messages. + pub fn is_error(&self) -> bool { + !self.is_informational() + } } impl Display for Notice { diff --git a/src/messages/tests.rs b/src/messages/tests.rs index de78688c..35667e33 100644 --- a/src/messages/tests.rs +++ b/src/messages/tests.rs @@ -1210,6 +1210,146 @@ fn test_notice_edge_cases() { } } +#[test] +fn test_notice_is_cancellation() { + // Code 202 = order cancelled + let cancellation = Notice { + code: 202, + message: "Order Cancelled - reason:".to_string(), + }; + assert!(cancellation.is_cancellation()); + assert!(!cancellation.is_warning()); + assert!(!cancellation.is_system_message()); + assert!(cancellation.is_informational()); + assert!(!cancellation.is_error()); + + // Other codes are not cancellations + let error = Notice { + code: 200, + message: "No security definition found".to_string(), + }; + assert!(!error.is_cancellation()); +} + +#[test] +fn test_notice_is_warning() { + // Codes 2100-2169 are warnings + let warning_codes = [2100, 2107, 2119, 2150, 2169]; + for code in warning_codes { + let notice = Notice { + code, + message: format!("Warning with code {}", code), + }; + assert!(notice.is_warning(), "Code {} should be a warning", code); + assert!(!notice.is_cancellation()); + assert!(!notice.is_system_message()); + assert!(notice.is_informational()); + assert!(!notice.is_error()); + } + + // Codes outside 2100-2169 are not warnings + let non_warning_codes = [2099, 2170, 200, 202, 1000]; + for code in non_warning_codes { + let notice = Notice { + code, + message: format!("Non-warning with code {}", code), + }; + assert!(!notice.is_warning(), "Code {} should not be a warning", code); + } +} + +#[test] +fn test_notice_is_system_message() { + // System message codes: 1100, 1101, 1102, 1300 + let system_codes = [ + (1100, "Connectivity between IB and TWS has been lost."), + (1101, "Connectivity restored, data lost."), + (1102, "Connectivity restored, data maintained."), + (1300, "Socket port has been reset."), + ]; + for (code, msg) in system_codes { + let notice = Notice { + code, + message: msg.to_string(), + }; + assert!(notice.is_system_message(), "Code {} should be a system message", code); + assert!(!notice.is_cancellation()); + assert!(!notice.is_warning()); + assert!(notice.is_informational()); + assert!(!notice.is_error()); + } + + // Non-system codes + let non_system_codes = [200, 202, 1099, 1103, 1299, 1301, 2100]; + for code in non_system_codes { + let notice = Notice { + code, + message: format!("Non-system message with code {}", code), + }; + assert!(!notice.is_system_message(), "Code {} should not be a system message", code); + } +} + +#[test] +fn test_notice_is_informational() { + // Informational includes cancellations, warnings, and system messages + let informational_codes = [202, 1100, 1101, 1102, 1300, 2100, 2107, 2169]; + for code in informational_codes { + let notice = Notice { + code, + message: format!("Informational code {}", code), + }; + assert!(notice.is_informational(), "Code {} should be informational", code); + assert!(!notice.is_error(), "Code {} should not be an error", code); + } + + // Non-informational (actual errors) + let error_codes = [100, 200, 201, 321, 502, 10000]; + for code in error_codes { + let notice = Notice { + code, + message: format!("Error code {}", code), + }; + assert!(!notice.is_informational(), "Code {} should not be informational", code); + assert!(notice.is_error(), "Code {} should be an error", code); + } +} + +#[test] +fn test_notice_is_error() { + // Code 200 = actual error + let error = Notice { + code: 200, + message: "No security definition found".to_string(), + }; + assert!(error.is_error()); + assert!(!error.is_informational()); + + // Code 202 = cancellation, not error + let cancellation = Notice { + code: 202, + message: "Order Cancelled".to_string(), + }; + assert!(!cancellation.is_error()); + assert!(cancellation.is_informational()); + + // Code 1100 = system message, not error + let system_msg = Notice { + code: 1100, + message: "Connectivity lost".to_string(), + }; + assert!(!system_msg.is_error()); + assert!(system_msg.is_informational()); + + // Code 2107 = warning, not error + let warning = Notice { + code: 2107, + message: "HMDS data farm connection is inactive.".to_string(), + }; + assert!(!warning.is_error()); + assert!(warning.is_informational()); +} + #[test] fn test_all_incoming_message_conversions() { // Test boundary values and ensure all message types are covered diff --git a/src/transport/routing.rs b/src/transport/routing.rs index a5e91ff5..63d28148 100644 --- a/src/transport/routing.rs +++ b/src/transport/routing.rs @@ -1,6 +1,6 @@ //! Common message routing logic for sync and async implementations -use crate::messages::{IncomingMessages, ResponseMessage}; +use crate::messages::{IncomingMessages, ResponseMessage, WARNING_CODE_RANGE}; /// Represents how a message should be routed #[derive(Debug, Clone, PartialEq)] @@ -72,9 +72,6 @@ pub fn determine_routing(message: &ResponseMessage) -> RoutingDecision { } } -/// Range of error codes that are considered warnings -pub const WARNING_CODE_RANGE: std::ops::RangeInclusive = 2100..=2169; - /// Check if an error code is a warning pub fn is_warning_error(error_code: i32) -> bool { WARNING_CODE_RANGE.contains(&error_code) From d560e482913ada1826767b58992ba9a91bf8eb41 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Wed, 7 Jan 2026 16:37:50 -0800 Subject: [PATCH 04/18] chore: bump version to 2.6.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a377a8f6..9fdf736c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ibapi" -version = "2.5.0" +version = "2.6.0" edition = "2021" authors = ["Wil Boayue "] description = "A Rust implementation of the Interactive Brokers TWS API, providing a reliable and user friendly interface for TWS and IB Gateway. Designed with a focus on simplicity and performance." From 15da4832d53d8acfc6d26be3fe81bd6b216794db Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 16 Jan 2026 17:43:17 -0800 Subject: [PATCH 05/18] fix: prevent scanner infinite hang on error messages (#371) * fix: prevent scanner infinite hang on error messages (#370) - Add explicit handling for IncomingMessages::Error in scanner decode functions (async and sync) to return Error::Message instead of Error::UnexpectedResponse - Add bounded retry logic (max 10 attempts) to subscription next() methods to prevent infinite loops from unexpected responses - Log warnings on retry attempts and errors when max retries exceeded Fixes #370 * refactor: extract shared scanner decode function Move scanner message type matching logic to decode_scanner_message() in the common decoders module to avoid duplication between async and sync implementations. * refactor: extract retry decision logic to common module Move retry checking and logging to check_retry() function in the common module. This consolidates the retry decision logic that was duplicated between async and sync subscription implementations. - Add RetryDecision enum to represent continue/stop decisions - Add check_retry() function that handles counting and logging - Update async and sync subscriptions to use the shared function - Add test for check_retry() * style: fix variable naming and format code Remove underscore prefix from 'err' variable since it is actually used in the should_retry_error() call. --- src/scanner/async.rs | 5 +---- src/scanner/common/decoders.rs | 12 ++++++++++- src/scanner/sync.rs | 7 ++----- src/subscriptions/async.rs | 35 ++++++++++++++++++++------------ src/subscriptions/common.rs | 37 ++++++++++++++++++++++++++++++++++ src/subscriptions/sync.rs | 25 +++++++++++++++++------ 6 files changed, 92 insertions(+), 29 deletions(-) diff --git a/src/scanner/async.rs b/src/scanner/async.rs index 8c9e2fae..67202b83 100644 --- a/src/scanner/async.rs +++ b/src/scanner/async.rs @@ -16,10 +16,7 @@ impl StreamDecoder> for Vec { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::ScannerData]; fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result, Error> { - match message.message_type() { - IncomingMessages::ScannerData => Ok(decoders::decode_scanner_data(message.clone())?), - _ => Err(Error::UnexpectedResponse(message.clone())), - } + decoders::decode_scanner_message(message) } fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { diff --git a/src/scanner/common/decoders.rs b/src/scanner/common/decoders.rs index e95f2525..8efdc5a0 100644 --- a/src/scanner/common/decoders.rs +++ b/src/scanner/common/decoders.rs @@ -1,9 +1,19 @@ use crate::contracts::{Currency, Exchange, SecurityType, Symbol}; -use crate::messages::ResponseMessage; +use crate::messages::{IncomingMessages, ResponseMessage}; use crate::Error; use super::super::ScannerData; +/// Shared decode function for scanner data messages. +/// Handles message type matching and error conversion. +pub(in crate::scanner) fn decode_scanner_message(message: &mut ResponseMessage) -> Result, Error> { + match message.message_type() { + IncomingMessages::ScannerData => decode_scanner_data(message.clone()), + IncomingMessages::Error => Err(Error::from(message.clone())), + _ => Err(Error::UnexpectedResponse(message.clone())), + } +} + pub(in crate::scanner) fn decode_scanner_parameters(mut message: ResponseMessage) -> Result { message.skip(); // skip message type message.skip(); // skip message version diff --git a/src/scanner/sync.rs b/src/scanner/sync.rs index f4204909..f7313145 100644 --- a/src/scanner/sync.rs +++ b/src/scanner/sync.rs @@ -6,16 +6,13 @@ use super::common::{decoders, encoders}; use super::*; use crate::client::blocking::Subscription; use crate::client::{ResponseContext, StreamDecoder}; -use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; +use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage}; use crate::orders::TagValue; use crate::{client::sync::Client, server_versions, Error}; impl StreamDecoder> for Vec { fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result, Error> { - match message.message_type() { - IncomingMessages::ScannerData => Ok(decoders::decode_scanner_data(message.clone())?), - _ => Err(Error::UnexpectedResponse(message.clone())), - } + decoders::decode_scanner_message(message) } fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { diff --git a/src/subscriptions/async.rs b/src/subscriptions/async.rs index 4e10a025..4fb6810b 100644 --- a/src/subscriptions/async.rs +++ b/src/subscriptions/async.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use log::{debug, warn}; use tokio::sync::mpsc; -use super::common::{process_decode_result, ProcessingResult}; +use super::common::{check_retry, process_decode_result, ProcessingResult, RetryDecision}; use super::{ResponseContext, StreamDecoder}; use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage}; use crate::transport::{AsyncInternalSubscription, AsyncMessageBus}; @@ -231,21 +231,30 @@ impl Subscription { subscription, decoder, server_version, - } => loop { - match subscription.next().await { - Some(Ok(mut message)) => { - let result = decoder(*server_version, &mut message); - match process_decode_result(result) { - ProcessingResult::Success(val) => return Some(Ok(val)), - ProcessingResult::EndOfStream => return None, - ProcessingResult::Retry => continue, - ProcessingResult::Error(err) => return Some(Err(err)), + } => { + let mut retry_count = 0; + loop { + match subscription.next().await { + Some(Ok(mut message)) => { + let result = decoder(*server_version, &mut message); + match process_decode_result(result) { + ProcessingResult::Success(val) => return Some(Ok(val)), + ProcessingResult::EndOfStream => return None, + ProcessingResult::Retry => { + if check_retry(retry_count) == RetryDecision::Stop { + return None; + } + retry_count += 1; + continue; + } + ProcessingResult::Error(err) => return Some(Err(err)), + } } + Some(Err(e)) => return Some(Err(e)), + None => return None, } - Some(Err(e)) => return Some(Err(e)), - None => return None, } - }, + } SubscriptionInner::PreDecoded { receiver } => receiver.recv().await, } } diff --git a/src/subscriptions/common.rs b/src/subscriptions/common.rs index 65a597c0..6f707495 100644 --- a/src/subscriptions/common.rs +++ b/src/subscriptions/common.rs @@ -3,6 +3,31 @@ use crate::errors::Error; use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; +/// Maximum number of retry attempts when encountering unexpected responses. +/// This prevents infinite loops when TWS sends unexpected message types. +pub(crate) const MAX_DECODE_RETRIES: usize = 10; + +/// Result of checking whether a retry should be attempted +#[derive(Debug, PartialEq)] +pub(crate) enum RetryDecision { + /// Continue retrying + Continue, + /// Stop retrying, max attempts exceeded + Stop, +} + +/// Checks if a retry should be attempted and logs appropriately. +/// Returns `RetryDecision::Continue` if retry count is below max, `RetryDecision::Stop` otherwise. +pub(crate) fn check_retry(retry_count: usize) -> RetryDecision { + if retry_count < MAX_DECODE_RETRIES { + log::warn!("retrying after unexpected response (attempt {}/{})", retry_count + 1, MAX_DECODE_RETRIES); + RetryDecision::Continue + } else { + log::error!("max retries ({}) exceeded, stopping subscription", MAX_DECODE_RETRIES); + RetryDecision::Stop + } +} + /// Checks if an error indicates the subscription should retry processing #[allow(dead_code)] pub(crate) fn should_retry_error(error: &Error) -> bool { @@ -73,6 +98,18 @@ mod tests { assert!(should_store_error(&Error::ConnectionFailed)); } + #[test] + fn test_check_retry() { + // Should continue when under max retries + assert_eq!(check_retry(0), RetryDecision::Continue); + assert_eq!(check_retry(5), RetryDecision::Continue); + assert_eq!(check_retry(MAX_DECODE_RETRIES - 1), RetryDecision::Continue); + + // Should stop when at or over max retries + assert_eq!(check_retry(MAX_DECODE_RETRIES), RetryDecision::Stop); + assert_eq!(check_retry(MAX_DECODE_RETRIES + 1), RetryDecision::Stop); + } + #[test] fn test_process_decode_result() { // Test success case diff --git a/src/subscriptions/sync.rs b/src/subscriptions/sync.rs index 9cf8806f..dddfa950 100644 --- a/src/subscriptions/sync.rs +++ b/src/subscriptions/sync.rs @@ -1,13 +1,13 @@ //! Synchronous subscription implementation use std::marker::PhantomData; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; use log::{debug, error, warn}; -use super::common::{process_decode_result, should_retry_error, should_store_error, ProcessingResult}; +use super::common::{check_retry, process_decode_result, should_retry_error, should_store_error, ProcessingResult, RetryDecision}; use super::{ResponseContext, StreamDecoder}; use crate::errors::Error; use crate::messages::{OutgoingMessages, ResponseMessage}; @@ -31,6 +31,7 @@ pub struct Subscription> { subscription: InternalSubscription, response_context: Option, error: Mutex>, + retry_count: AtomicUsize, } #[allow(private_bounds)] @@ -57,6 +58,7 @@ impl> Subscription { cancelled: AtomicBool::new(false), snapshot_ended: AtomicBool::new(false), error: Mutex::new(None), + retry_count: AtomicUsize::new(0), } } @@ -136,13 +138,24 @@ impl> Subscription { /// * `None` - If the subscription has ended or encountered an error pub fn next(&self) -> Option { match self.process_response(self.subscription.next()) { - Some(val) => Some(val), + Some(val) => { + self.retry_count.store(0, Ordering::Relaxed); + Some(val) + } None => match self.error() { Some(ref err) if should_retry_error(err) => { - debug!("retrying after error: {err:?}"); - self.next() + let retries = self.retry_count.fetch_add(1, Ordering::Relaxed); + if check_retry(retries) == RetryDecision::Continue { + self.next() + } else { + self.retry_count.store(0, Ordering::Relaxed); + None + } + } + _ => { + self.retry_count.store(0, Ordering::Relaxed); + None } - _ => None, }, } } From 4cc4f8ea95e24ccb7efaa244c11681232e908c26 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 16 Jan 2026 18:37:45 -0800 Subject: [PATCH 06/18] chore: bump version to 2.6.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9fdf736c..73033333 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ibapi" -version = "2.6.0" +version = "2.6.1" edition = "2021" authors = ["Wil Boayue "] description = "A Rust implementation of the Interactive Brokers TWS API, providing a reliable and user friendly interface for TWS and IB Gateway. Designed with a focus on simplicity and performance." From 54280a857c6503863af2da2ac9f32ddb2897e086 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 23 Jan 2026 12:01:02 -0800 Subject: [PATCH 07/18] fix: prevent bracket/OCA order ID collisions (#375) Reserve all order IDs upfront by calling next_order_id() for each order instead of using base_id + offset calculation. Fixes #374 --- src/orders/builder/async_impl.rs | 19 ++++++++++--------- src/orders/builder/sync_impl.rs | 19 ++++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/orders/builder/async_impl.rs b/src/orders/builder/async_impl.rs index 05017048..a3f7fd41 100644 --- a/src/orders/builder/async_impl.rs +++ b/src/orders/builder/async_impl.rs @@ -54,19 +54,21 @@ impl<'a> BracketOrderBuilder<'a, Client> { pub async fn submit_all(self) -> Result { let client = self.parent_builder.client; let contract = self.parent_builder.contract; - let base_id = client.next_order_id(); let orders = self.build()?; - let mut order_ids = Vec::new(); + // Reserve all order IDs upfront to prevent collisions + let parent_id = client.next_order_id(); + let tp_id = client.next_order_id(); + let sl_id = client.next_order_id(); + let reserved_ids = [parent_id, tp_id, sl_id]; for (i, mut order) in orders.into_iter().enumerate() { - let order_id = base_id + i as i32; + let order_id = reserved_ids[i]; order.order_id = order_id; - order_ids.push(order_id); // Update parent_id for child orders if i > 0 { - order.parent_id = base_id; + order.parent_id = parent_id; } // Only transmit the last order @@ -77,7 +79,7 @@ impl<'a> BracketOrderBuilder<'a, Client> { orders::submit_order(client, order_id, contract, &order).await?; } - Ok(BracketOrderIds::new(order_ids[0], order_ids[1], order_ids[2])) + Ok(BracketOrderIds::new(parent_id, tp_id, sl_id)) } } @@ -118,10 +120,9 @@ impl Client { /// ``` pub async fn submit_oca_orders(&self, orders: Vec<(crate::contracts::Contract, crate::orders::Order)>) -> Result, Error> { let mut order_ids = Vec::new(); - let base_id = self.next_order_id(); - for (i, (contract, mut order)) in orders.into_iter().enumerate() { - let order_id = base_id + i as i32; + for (contract, mut order) in orders.into_iter() { + let order_id = self.next_order_id(); order.order_id = order_id; order_ids.push(OrderId::new(order_id)); orders::submit_order(self, order_id, &contract, &order).await?; diff --git a/src/orders/builder/sync_impl.rs b/src/orders/builder/sync_impl.rs index 615dff20..ca0e531e 100644 --- a/src/orders/builder/sync_impl.rs +++ b/src/orders/builder/sync_impl.rs @@ -55,19 +55,21 @@ impl<'a> BracketOrderBuilder<'a, Client> { pub fn submit_all(self) -> Result { let client = self.parent_builder.client; let contract = self.parent_builder.contract; - let base_id = client.next_order_id(); let orders = self.build()?; - let mut order_ids = Vec::new(); + // Reserve all order IDs upfront to prevent collisions + let parent_id = client.next_order_id(); + let tp_id = client.next_order_id(); + let sl_id = client.next_order_id(); + let reserved_ids = [parent_id, tp_id, sl_id]; for (i, mut order) in orders.into_iter().enumerate() { - let order_id = base_id + i as i32; + let order_id = reserved_ids[i]; order.order_id = order_id; - order_ids.push(order_id); // Update parent_id for child orders if i > 0 { - order.parent_id = base_id; + order.parent_id = parent_id; } // Only transmit the last order @@ -78,7 +80,7 @@ impl<'a> BracketOrderBuilder<'a, Client> { orders::blocking::submit_order(client, order_id, contract, &order)?; } - Ok(BracketOrderIds::new(order_ids[0], order_ids[1], order_ids[2])) + Ok(BracketOrderIds::new(parent_id, tp_id, sl_id)) } } @@ -116,10 +118,9 @@ impl Client { /// ``` pub fn submit_oca_orders(&self, orders: Vec<(Contract, crate::orders::Order)>) -> Result, Error> { let mut order_ids = Vec::new(); - let base_id = self.next_order_id(); - for (i, (contract, mut order)) in orders.into_iter().enumerate() { - let order_id = base_id + i as i32; + for (contract, mut order) in orders.into_iter() { + let order_id = self.next_order_id(); order.order_id = order_id; order_ids.push(OrderId::new(order_id)); orders::blocking::submit_order(self, order_id, &contract, &order)?; From 3f14a7680ee7845bc30a45958b33d284727ab41a Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 23 Jan 2026 12:08:57 -0800 Subject: [PATCH 08/18] fix: bracket child orders inherit outside_rth from parent (#373) (#376) Take profit and stop loss orders now inherit the outside_rth flag from the parent order, enabling bracket orders to work correctly during extended trading hours. --- src/orders/builder/order_builder.rs | 2 ++ src/orders/builder/order_builder/tests.rs | 37 +++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/orders/builder/order_builder.rs b/src/orders/builder/order_builder.rs index edc088db..00d4f3c6 100644 --- a/src/orders/builder/order_builder.rs +++ b/src/orders/builder/order_builder.rs @@ -999,6 +999,7 @@ impl<'a, C> BracketOrderBuilder<'a, C> { limit_price: Some(take_profit.value()), parent_id: parent.order_id, transmit: false, + outside_rth: parent.outside_rth, ..Default::default() }; @@ -1010,6 +1011,7 @@ impl<'a, C> BracketOrderBuilder<'a, C> { aux_price: Some(stop_loss.value()), parent_id: parent.order_id, transmit: true, + outside_rth: parent.outside_rth, ..Default::default() }; diff --git a/src/orders/builder/order_builder/tests.rs b/src/orders/builder/order_builder/tests.rs index 45c53218..759cb73c 100644 --- a/src/orders/builder/order_builder/tests.rs +++ b/src/orders/builder/order_builder/tests.rs @@ -746,6 +746,43 @@ fn test_bracket_order_types() { assert_eq!(orders[2].order_type, "STP"); // Stop loss is stop } +#[test] +fn test_bracket_order_inherits_outside_rth() { + let client = MockClient; + let contract = create_test_contract(); + + // Test with outside_rth enabled + let bracket = OrderBuilder::new(&client, &contract) + .buy(100) + .outside_rth() + .bracket() + .entry_limit(50.0) + .take_profit(55.0) + .stop_loss(45.0); + + let orders = bracket.build().unwrap(); + + // All orders should inherit outside_rth from parent + assert!(orders[0].outside_rth, "Parent should have outside_rth"); + assert!(orders[1].outside_rth, "Take profit should inherit outside_rth"); + assert!(orders[2].outside_rth, "Stop loss should inherit outside_rth"); + + // Test without outside_rth (default) + let bracket = OrderBuilder::new(&client, &contract) + .buy(100) + .bracket() + .entry_limit(50.0) + .take_profit(55.0) + .stop_loss(45.0); + + let orders = bracket.build().unwrap(); + + // All orders should have outside_rth = false + assert!(!orders[0].outside_rth); + assert!(!orders[1].outside_rth); + assert!(!orders[2].outside_rth); +} + #[test] fn test_bracket_order_with_missing_action() { let client = MockClient; From df9e0adbad10870a32b4e22397acc5415aca8000 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 23 Jan 2026 12:30:12 -0800 Subject: [PATCH 09/18] feat: add entry_market() for bracket orders (#372) (#377) Allow market entry for bracket orders for immediate execution in scalping scenarios. --- Cargo.toml | 2 +- src/orders/builder/order_builder.rs | 58 +++++++--- src/orders/builder/order_builder/tests.rs | 133 +++++++++++++++++++++- 3 files changed, 173 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 73033333..8ea17cb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ibapi" -version = "2.6.1" +version = "2.6.2" edition = "2021" authors = ["Wil Boayue "] description = "A Rust implementation of the Interactive Brokers TWS API, providing a reliable and user friendly interface for TWS and IB Gateway. Designed with a focus on simplicity and performance." diff --git a/src/orders/builder/order_builder.rs b/src/orders/builder/order_builder.rs index 00d4f3c6..072262fc 100644 --- a/src/orders/builder/order_builder.rs +++ b/src/orders/builder/order_builder.rs @@ -928,10 +928,19 @@ fn set_conjunction(condition: &mut OrderCondition, is_conjunction: bool) { } } +/// Entry order type for bracket orders +#[derive(Default)] +enum BracketEntryType { + #[default] + None, + Limit(f64), + Market, +} + /// Builder for bracket orders pub struct BracketOrderBuilder<'a, C> { pub(crate) parent_builder: OrderBuilder<'a, C>, - entry_price: Option, + entry_type: BracketEntryType, take_profit_price: Option, stop_loss_price: Option, } @@ -940,15 +949,21 @@ impl<'a, C> BracketOrderBuilder<'a, C> { fn new(parent_builder: OrderBuilder<'a, C>) -> Self { Self { parent_builder, - entry_price: None, + entry_type: BracketEntryType::None, take_profit_price: None, stop_loss_price: None, } } + /// Set entry as market order (immediate execution) + pub fn entry_market(mut self) -> Self { + self.entry_type = BracketEntryType::Market; + self + } + /// Set entry limit price pub fn entry_limit(mut self, price: impl Into) -> Self { - self.entry_price = Some(price.into()); + self.entry_type = BracketEntryType::Limit(price.into()); self } @@ -966,26 +981,35 @@ impl<'a, C> BracketOrderBuilder<'a, C> { /// Build bracket orders with full validation pub fn build(mut self) -> Result, ValidationError> { - // Validate and convert prices - let entry_price_raw = self.entry_price.ok_or(ValidationError::MissingRequiredField("entry_price"))?; + // Validate and convert take profit and stop loss prices let take_profit_raw = self.take_profit_price.ok_or(ValidationError::MissingRequiredField("take_profit"))?; let stop_loss_raw = self.stop_loss_price.ok_or(ValidationError::MissingRequiredField("stop_loss"))?; - let entry_price = Price::new(entry_price_raw)?; let take_profit = Price::new(take_profit_raw)?; let stop_loss = Price::new(stop_loss_raw)?; - // Validate bracket order prices - validation::validate_bracket_prices( - self.parent_builder.action.as_ref(), - entry_price.value(), - take_profit.value(), - stop_loss.value(), - )?; - - // Set the entry limit price on parent builder - self.parent_builder.order_type = Some(OrderType::Limit); - self.parent_builder.limit_price = Some(entry_price.value()); + // Set order type based on entry type + match self.entry_type { + BracketEntryType::None => { + return Err(ValidationError::MissingRequiredField("entry (use entry_limit() or entry_market())")); + } + BracketEntryType::Limit(price) => { + let entry_price = Price::new(price)?; + // Validate bracket order prices + validation::validate_bracket_prices( + self.parent_builder.action.as_ref(), + entry_price.value(), + take_profit.value(), + stop_loss.value(), + )?; + self.parent_builder.order_type = Some(OrderType::Limit); + self.parent_builder.limit_price = Some(entry_price.value()); + } + BracketEntryType::Market => { + // Skip price relationship validation for market orders + self.parent_builder.order_type = Some(OrderType::Market); + } + } // Build parent order let mut parent = self.parent_builder.build()?; diff --git a/src/orders/builder/order_builder/tests.rs b/src/orders/builder/order_builder/tests.rs index 759cb73c..02ec1674 100644 --- a/src/orders/builder/order_builder/tests.rs +++ b/src/orders/builder/order_builder/tests.rs @@ -542,7 +542,7 @@ fn test_bracket_order_validation_buy() { } #[test] -fn test_bracket_order_missing_entry_price() { +fn test_bracket_order_missing_entry() { let client = MockClient; let contract = create_test_contract(); @@ -550,7 +550,7 @@ fn test_bracket_order_missing_entry_price() { let result = bracket.build(); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("entry_price")); + assert!(result.unwrap_err().to_string().contains("entry")); } #[test] @@ -799,6 +799,135 @@ fn test_bracket_order_with_missing_action() { assert!(result.is_err()); } +// ===== Market Entry Bracket Order Tests ===== + +#[test] +fn test_bracket_order_market_entry_buy() { + let client = MockClient; + let contract = create_test_contract(); + + let bracket = OrderBuilder::new(&client, &contract) + .buy(100) + .bracket() + .entry_market() + .take_profit(55.0) + .stop_loss(45.0); + + let orders = bracket.build().unwrap(); + assert_eq!(orders.len(), 3); + + // Verify parent order is market order + let parent = &orders[0]; + assert_eq!(parent.action, Action::Buy); + assert_eq!(parent.order_type, "MKT"); + assert_eq!(parent.limit_price, None); + assert!(!parent.transmit); + + // Verify take profit details + let tp = &orders[1]; + assert_eq!(tp.action, Action::Sell); + assert_eq!(tp.order_type, "LMT"); + assert_eq!(tp.limit_price, Some(55.0)); + assert_eq!(tp.parent_id, parent.order_id); + assert!(!tp.transmit); + + // Verify stop loss details + let sl = &orders[2]; + assert_eq!(sl.action, Action::Sell); + assert_eq!(sl.order_type, "STP"); + assert_eq!(sl.aux_price, Some(45.0)); + assert_eq!(sl.parent_id, parent.order_id); + assert!(sl.transmit); +} + +#[test] +fn test_bracket_order_market_entry_sell() { + let client = MockClient; + let contract = create_test_contract(); + + let bracket = OrderBuilder::new(&client, &contract) + .sell(100) + .bracket() + .entry_market() + .take_profit(45.0) + .stop_loss(55.0); + + let orders = bracket.build().unwrap(); + assert_eq!(orders.len(), 3); + + // Verify parent order is market order with Sell action + let parent = &orders[0]; + assert_eq!(parent.action, Action::Sell); + assert_eq!(parent.order_type, "MKT"); + + // Verify child orders have reversed action + let tp = &orders[1]; + assert_eq!(tp.action, Action::Buy); + + let sl = &orders[2]; + assert_eq!(sl.action, Action::Buy); +} + +#[test] +fn test_bracket_order_market_entry_inherits_outside_rth() { + let client = MockClient; + let contract = create_test_contract(); + + let bracket = OrderBuilder::new(&client, &contract) + .buy(100) + .outside_rth() + .bracket() + .entry_market() + .take_profit(55.0) + .stop_loss(45.0); + + let orders = bracket.build().unwrap(); + + // All orders should inherit outside_rth from parent + assert!(orders[0].outside_rth, "Parent should have outside_rth"); + assert!(orders[1].outside_rth, "Take profit should inherit outside_rth"); + assert!(orders[2].outside_rth, "Stop loss should inherit outside_rth"); +} + +#[test] +fn test_bracket_order_market_entry_quantity_propagation() { + let client = MockClient; + let contract = create_test_contract(); + + let bracket = OrderBuilder::new(&client, &contract) + .buy(500) + .bracket() + .entry_market() + .take_profit(55.0) + .stop_loss(45.0); + + let orders = bracket.build().unwrap(); + + // All orders should have the same quantity + assert_eq!(orders[0].total_quantity, 500.0); + assert_eq!(orders[1].total_quantity, 500.0); + assert_eq!(orders[2].total_quantity, 500.0); +} + +#[test] +fn test_bracket_order_market_entry_parent_id_propagation() { + let client = MockClient; + let contract = create_test_contract(); + + let bracket = OrderBuilder::new(&client, &contract) + .buy(100) + .bracket() + .entry_market() + .take_profit(55.0) + .stop_loss(45.0); + + let orders = bracket.build().unwrap(); + let parent_id = orders[0].order_id; + + assert_eq!(orders[1].parent_id, parent_id); + assert_eq!(orders[2].parent_id, parent_id); +} + #[test] fn test_market_on_close() { let client = MockClient; From 7cc5f9ea46452989ec494b172723f8e14899cf20 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 23 Jan 2026 21:53:56 -0800 Subject: [PATCH 10/18] docs: add DRY, SRP, and composition guidelines (#378) --- CLAUDE.md | 5 ++++ docs/api-patterns.md | 61 +++++++++++++++++++++++++++++++++++++++++-- docs/code-style.md | 39 +++++++++++++++++++++++++++ docs/extending-api.md | 59 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 2 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 39f05497..c4c638af 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -53,6 +53,11 @@ cargo test --all-features 4. **Minimal comments**: Keep comments concise, avoid stating the obvious 5. **Run quality checks**: Before committing, run `cargo fmt`, `cargo clippy --features sync`, and `cargo clippy --features async` 6. **Fluent conditional orders**: Use helper functions (`price()`, `time()`, `margin()`, etc.) and method chaining (`.condition()`, `.and_condition()`, `.or_condition()`) for building conditional orders. See [docs/order-types.md](docs/order-types.md#conditional-orders-with-conditions) and [docs/api-patterns.md](docs/api-patterns.md#conditional-order-builder-pattern) for details +7. **Don't repeat code**: Extract repeated logic to `common/`; use shared helpers like `request_helpers` +8. **Single responsibility**: One responsibility per function/module; split orchestration from business logic +9. **Composition**: Single responsibility per struct; use builders for complex construction; max 3 params per function (use builder if 4+) + +See [docs/code-style.md](docs/code-style.md#design-principles) for detailed design guidelines. ## Connection Settings diff --git a/docs/api-patterns.md b/docs/api-patterns.md index 3ad43c8c..aeafc425 100644 --- a/docs/api-patterns.md +++ b/docs/api-patterns.md @@ -523,11 +523,68 @@ loop { thread::sleep(Duration::from_secs(5)); continue; } - + match perform_operation(&client) { Ok(result) => return Ok(result), Err(Error::NotConnected) => continue, Err(e) => return Err(e), } } -``` \ No newline at end of file +``` + +## Trait Composition Patterns + +### Domain Traits for Shared Behavior +```rust +// Use domain traits when: +// - 2+ types need the same operation (encode, decode, validate) +// - You want to write generic functions over those types +pub trait Encodable { + fn encode(&self, message: &mut RequestMessage) -> Result<(), Error>; +} + +pub trait Decodable: Sized { + fn decode(fields: &mut FieldIter) -> Result; +} + +// Implement for types that need this behavior +impl Encodable for Order { /* ... */ } +impl Encodable for Contract { /* ... */ } +``` + +### Extension via Composition +```rust +// Use composition when: +// - A type needs capabilities from multiple sources +// - Behavior should be added without modifying the original type +pub struct Subscription { + receiver: Receiver, + cancel_fn: Box, +} + +// Add behavior via trait impls +impl Iterator for Subscription { /* ... */ } +impl Drop for Subscription { /* ... */ } +``` + +### Newtype Wrappers for Domain Constraints +```rust +// Bad: raw i32 allows invalid IDs and type confusion +fn lookup(contract_id: i32, order_id: i32) -> Contract { /* ... */ } // easy to swap args + +// Good: newtype wrappers prevent mistakes +// Use newtype wrappers when: +// - A primitive has domain constraints (non-zero, positive, etc.) +// - Type confusion is possible (ContractId vs OrderId) +pub struct ContractId(i32); + +impl ContractId { + pub fn new(id: i32) -> Result { + if id <= 0 { return Err(Error::InvalidContractId); } + Ok(Self(id)) + } +} + +// Type system prevents invalid states +fn lookup(id: ContractId) -> Contract { /* ... */ } // Can't pass raw i32 +``` diff --git a/docs/code-style.md b/docs/code-style.md index f805d849..48a4757f 100644 --- a/docs/code-style.md +++ b/docs/code-style.md @@ -1,5 +1,44 @@ # Code Style Guidelines +## Design Principles + +### DRY (Don't Repeat Yourself) +- Extract shared logic to `common/` modules +- Use `request_helpers` for common request patterns +- Prefer traits over code duplication across types +- **When to extract**: 2+ occurrences differing only in parameter values (same logic, types, and control flow) +- **When NOT to extract**: One-time code, or when extraction obscures intent + +### SRP (Single Responsibility Principle) +- One encoder/decoder per message type +- Modules own one domain (accounts, orders, market_data) +- Functions do one thing: encode, decode, validate, or orchestrate +- Max 50 lines per function; extract if larger +- Functions with 4+ parameters should use a builder pattern + +### Composition +- Combine small, focused components to build complex behavior +- Use traits to define shared behavior across types +- Compose complex types from smaller building blocks: + ```rust + // Good: use builder when 4+ params, optional params, or complex construction + let order = order_builder::limit_order(Action::Buy, 100.0, 150.0) + .condition(price_condition) + .build(); + + // Bad: monolithic constructor with 4+ params + let order = Order::new(Action::Buy, 100.0, 150.0, Some(cond), None, None); + ``` +- Prefer `impl Trait` for flexible return types +- Use newtype wrappers for domain constraints + +### When Principles Conflict +- Clarity > DRY (if a reader must jump to another file to understand the flow, don't extract) +- SRP > brevity (split even if it adds lines) +- Consistency with existing code > ideal patterns + +See [extending-api.md#anti-patterns-to-avoid](extending-api.md#anti-patterns-to-avoid) for code examples of these violations. + ## Comments - **Keep comments concise and avoid redundancy**. Don't state the obvious. diff --git a/docs/extending-api.md b/docs/extending-api.md index 79e6b032..61f10ce3 100644 --- a/docs/extending-api.md +++ b/docs/extending-api.md @@ -2,6 +2,65 @@ This guide covers advanced topics for extending the rust-ibapi functionality. +## Anti-Patterns to Avoid + +These examples demonstrate violations of principles in [code-style.md](code-style.md#design-principles). + +### Duplicated Logic +```rust +// Bad: duplicated validation in sync and async +pub fn my_func(client: &Client, param: &str) -> Result { + if param.is_empty() { return Err(Error::InvalidParam); } + // ... +} +pub async fn my_func(client: &Client, param: &str) -> Result { + if param.is_empty() { return Err(Error::InvalidParam); } // duplicate! + // ... +} +``` + +```rust +// Good: shared validation in common/ +pub(crate) fn validate_param(param: &str) -> Result<(), Error> { + if param.is_empty() { return Err(Error::InvalidParam); } + Ok(()) +} + +// Usage in sync.rs and async.rs +validate_param(param)?; +``` + +### Monolithic Functions +```rust +// Bad: function does encoding, validation, and error handling +pub fn place_order(client: &Client, order: &Order) -> Result<(), Error> { + // 100+ lines of mixed concerns +} +``` + +```rust +// Good: split by responsibility +pub fn place_order(client: &Client, order: &Order) -> Result<(), Error> { + validate_order(order)?; + let request = encode_order(order)?; + send_and_handle_response(client, request) +} +``` + +### Large Parameter Lists +```rust +// Bad: 4+ params signal need for builder +fn create_order(action: Action, qty: f64, price: f64, tif: TimeInForce, + oca: Option, cond: Option) { } + +// Good: use builder pattern +order_builder::limit_order(action, qty, price) + .time_in_force(tif) + .oca_group(oca) + .condition(cond) + .build() +``` + ## Module Organization Each API module follows a consistent structure to support both sync and async modes: From b56466b32b3fa6b47deb36d14c9f96da689c5329 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 24 Jan 2026 19:23:51 -0800 Subject: [PATCH 11/18] Add Sec10 bar size variant (#380) Fixes #379 - 10-second bar granularity was missing from BarSize enum. --- src/market_data/historical/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/market_data/historical/mod.rs b/src/market_data/historical/mod.rs index 9a5d7b5e..79279ecf 100644 --- a/src/market_data/historical/mod.rs +++ b/src/market_data/historical/mod.rs @@ -65,6 +65,8 @@ pub enum BarSize { Sec, /// Five-second bars. Sec5, + /// Ten-second bars. + Sec10, /// Fifteen-second bars. Sec15, /// Thirty-second bars. @@ -106,6 +108,7 @@ impl Display for BarSize { match self { Self::Sec => write!(f, "1 secs"), Self::Sec5 => write!(f, "5 secs"), + Self::Sec10 => write!(f, "10 secs"), Self::Sec15 => write!(f, "15 secs"), Self::Sec30 => write!(f, "30 secs"), Self::Min => write!(f, "1 min"), From 8b0fb01ea99a5fe3e5653d338c5ecece50651cca Mon Sep 17 00:00:00 2001 From: Vitaly Kravchenko Date: Mon, 26 Jan 2026 22:05:28 +0000 Subject: [PATCH 12/18] Update BarSize::from_str and tests. (#381) --- src/market_data/historical/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/market_data/historical/mod.rs b/src/market_data/historical/mod.rs index 79279ecf..f61953e5 100644 --- a/src/market_data/historical/mod.rs +++ b/src/market_data/historical/mod.rs @@ -137,6 +137,7 @@ impl FromStr for BarSize { match s.to_uppercase().as_str() { "SEC" => Ok(Self::Sec), "SEC5" => Ok(Self::Sec5), + "SEC10" => Ok(Self::Sec10), "SEC15" => Ok(Self::Sec15), "SEC30" => Ok(Self::Sec30), "MIN" => Ok(Self::Min), @@ -572,6 +573,7 @@ mod tests { fn test_bar_size_to_string() { assert_eq!("1 secs", BarSize::Sec.to_string()); assert_eq!("5 secs", BarSize::Sec5.to_string()); + assert_eq!("10 secs", BarSize::Sec10.to_string()); assert_eq!("15 secs", BarSize::Sec15.to_string()); assert_eq!("30 secs", BarSize::Sec30.to_string()); assert_eq!("1 min", BarSize::Min.to_string()); @@ -595,6 +597,7 @@ mod tests { fn test_bar_size_from_string() { assert_eq!(BarSize::Sec, BarSize::from("SEC")); assert_eq!(BarSize::Sec5, BarSize::from("SEC5")); + assert_eq!(BarSize::Sec10, BarSize::from("SEC10")); assert_eq!(BarSize::Sec15, BarSize::from("SEC15")); assert_eq!(BarSize::Sec30, BarSize::from("SEC30")); assert_eq!(BarSize::Min, BarSize::from("MIN")); From 5e1a6bda91dc36ea51750674373ec5b9734a8a34 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Mon, 26 Jan 2026 14:36:47 -0800 Subject: [PATCH 13/18] fix: ForexBuilder sets symbol to base currency, currency to quote (#383) Fixes #382. Symbol was incorrectly set to "EUR.USD" format instead of just "EUR". Currency was hardcoded to "USD" instead of using the quote parameter. Also removes unused amount field. --- docs/contract-builder.md | 5 +---- examples/v2_contract_builder.rs | 2 +- src/contracts/builders.rs | 20 ++++++-------------- src/contracts/builders/tests.rs | 4 ++-- src/contracts/mod.rs | 4 +--- 5 files changed, 11 insertions(+), 24 deletions(-) diff --git a/docs/contract-builder.md b/docs/contract-builder.md index 138008f5..d8defa7f 100644 --- a/docs/contract-builder.md +++ b/docs/contract-builder.md @@ -173,13 +173,10 @@ Foreign exchange pairs with automatic pair formatting. use ibapi::contracts::Contract; // EUR/USD pair -let eur_usd = Contract::forex("EUR", "USD") - .amount(100_000) - .build(); +let eur_usd = Contract::forex("EUR", "USD").build(); // GBP/JPY with custom exchange let gbp_jpy = Contract::forex("GBP", "JPY") - .amount(50_000) .on_exchange("IDEALPRO") .build(); ``` diff --git a/examples/v2_contract_builder.rs b/examples/v2_contract_builder.rs index f47ccd40..42f6ad84 100644 --- a/examples/v2_contract_builder.rs +++ b/examples/v2_contract_builder.rs @@ -41,7 +41,7 @@ fn main() { ); // Forex pair - let eur_usd = Contract::forex("EUR", "USD").amount(100_000).build(); + let eur_usd = Contract::forex("EUR", "USD").build(); println!("Forex: {}", eur_usd.symbol); // Cryptocurrency diff --git a/src/contracts/builders.rs b/src/contracts/builders.rs index a6e75ed2..86c57326 100644 --- a/src/contracts/builders.rs +++ b/src/contracts/builders.rs @@ -300,29 +300,21 @@ impl FuturesBuilder { /// Forex pair builder #[derive(Debug, Clone)] pub struct ForexBuilder { - pair: String, + base: Currency, + quote: Currency, exchange: Exchange, - amount: u32, } impl ForexBuilder { /// Create a forex contract using the given base and quote currencies. pub fn new(base: impl Into, quote: impl Into) -> Self { - let base = base.into(); - let quote = quote.into(); ForexBuilder { - pair: format!("{}.{}", base, quote), + base: base.into(), + quote: quote.into(), exchange: "IDEALPRO".into(), - amount: 20_000, } } - /// Adjust the standard order amount. - pub fn amount(mut self, amount: u32) -> Self { - self.amount = amount; - self - } - /// Route the trade to a different forex venue. pub fn on_exchange(mut self, exchange: impl Into) -> Self { self.exchange = exchange.into(); @@ -332,10 +324,10 @@ impl ForexBuilder { /// Complete the forex contract definition. pub fn build(self) -> Contract { Contract { - symbol: Symbol::new(self.pair), + symbol: Symbol::new(self.base.0), security_type: SecurityType::ForexPair, exchange: self.exchange, - currency: "USD".into(), // Quote currency + currency: self.quote, ..Default::default() } } diff --git a/src/contracts/builders/tests.rs b/src/contracts/builders/tests.rs index 8530cdd5..18ece3bb 100644 --- a/src/contracts/builders/tests.rs +++ b/src/contracts/builders/tests.rs @@ -111,9 +111,9 @@ fn test_futures_multiplier() { #[test] fn test_forex_builder() { - let forex = Contract::forex("EUR", "USD").amount(100_000).on_exchange("IDEALPRO").build(); + let forex = Contract::forex("EUR", "USD").on_exchange("IDEALPRO").build(); - assert_eq!(forex.symbol, Symbol::from("EUR.USD")); + assert_eq!(forex.symbol, Symbol::from("EUR")); assert_eq!(forex.security_type, SecurityType::ForexPair); assert_eq!(forex.exchange, Exchange::from("IDEALPRO")); assert_eq!(forex.currency, Currency::from("USD")); diff --git a/src/contracts/mod.rs b/src/contracts/mod.rs index b4e20709..2202b601 100644 --- a/src/contracts/mod.rs +++ b/src/contracts/mod.rs @@ -288,9 +288,7 @@ impl Contract { /// ``` /// use ibapi::contracts::{Contract, Currency}; /// - /// let eur_usd = Contract::forex("EUR", "USD") - /// .amount(100_000) - /// .build(); + /// let eur_usd = Contract::forex("EUR", "USD").build(); /// ``` pub fn forex(base: impl Into, quote: impl Into) -> ForexBuilder { ForexBuilder::new(base, quote) From 7770ede9a90061b20f1b1c330a434003f87b58ce Mon Sep 17 00:00:00 2001 From: mysyzygy Date: Fri, 30 Jan 2026 19:38:53 -0800 Subject: [PATCH 14/18] Add historical_data_streaming() with keepUpToDate=true support (#384) * feat: add historical_data_streaming() with keepUpToDate=true support Add streaming historical data support for IBKR's reqHistoricalData API with keepUpToDate parameter enabled. This allows receiving continuous bar updates as they build, enabling live trading strategies to receive bars at any IBKR-supported resolution. Changes: - Add HistoricalBarUpdate enum to represent streaming updates - Add decode_historical_data_update() for message type 90 - Add historical_data_streaming() async function that sets keepUpToDate=true - Add HistoricalDataStreamingSubscription for handling streaming responses IBKR behavior notes: - Same timestamp bars are sent ~4-6 seconds apart as they build - When a NEW timestamp appears, the previous bar is complete - Supported what_to_show: Trades, Midpoint, Bid, Ask only - end_date must be None when keepUpToDate=true * add missing historical_data_streaming function * Change channel-closed logs from warn to info * fixed historical_data_streaming to properly received historical streaming data with keep_up_to_date=true * fixed formatting and doc tests * Revert "Change channel-closed logs from warn to" This reverts commit 065a65e8c3de5d0ac3da2a9d5ff2d5512f66eaae. * fix: use proper contract builders in example - Use Contract::forex() builder instead of manual field assignment - Use struct initialization for futures query to satisfy clippy --------- Co-authored-by: Wil Boayue --- Cargo.toml | 2 +- examples/async/historical_data.rs | 358 +++++++++++------ src/client/async.rs | 55 +++ src/lib.rs | 2 +- src/market_data/historical/async.rs | 365 +++++++++++++++++- src/market_data/historical/common/decoders.rs | 85 ++++ src/market_data/historical/mod.rs | 19 +- src/messages.rs | 1 + src/orders/builder/order_builder.rs | 4 +- 9 files changed, 768 insertions(+), 123 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8ea17cb5..0e6aea58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ async-trait = { version = "0.1.83", optional = true } [dev-dependencies] anyhow = "1.0.92" -clap = "4.5.20" +clap = { version = "4.5.20", features = ["derive"] } env_logger = "0.11.5" pretty_assertions = "1.4.1" tempfile = "3.13" diff --git a/examples/async/historical_data.rs b/examples/async/historical_data.rs index cc9889bf..a9db50f4 100644 --- a/examples/async/historical_data.rs +++ b/examples/async/historical_data.rs @@ -9,76 +9,129 @@ //! Make sure IB Gateway or TWS is running with API connections enabled, then run: //! //! ```bash -//! cargo run --features async --example async_historical_data +//! cargo run --example async_historical_data +//! cargo run --example async_historical_data -- --asset forex +//! cargo run --example async_historical_data -- --asset futures //! ``` //! //! # Configuration //! //! - Adjust the connection address if needed (default: 127.0.0.1:4002) -//! - Change the stock symbol if desired (default: AAPL) +//! - Use --asset to select: stock (AAPL), forex (EUR.USD), or futures (ES) //! - Modify duration and bar size to get different data periods use std::sync::Arc; +use clap::{Parser, ValueEnum}; +use ibapi::contracts::SecurityType; +use ibapi::market_data::historical::HistoricalBarUpdate; use ibapi::prelude::*; use time::OffsetDateTime; +#[derive(Parser)] +#[command(name = "historical_data")] +#[command(about = "Fetch historical bar data from IB")] +struct Args { + /// Asset type to use + #[arg(long, value_enum, default_value = "stock")] + asset: AssetType, + + /// Skip to streaming test only + #[arg(long, short = 's')] + streaming_only: bool, +} + +#[derive(Clone, Debug, ValueEnum)] +enum AssetType { + Stock, + Forex, + Futures, +} + #[tokio::main] async fn main() -> Result<(), Box> { env_logger::init(); + let args = Args::parse(); // Connect to IB Gateway let client = Arc::new(Client::connect("127.0.0.1:4002", 100).await?); println!("Connected to IB Gateway"); - // Create a stock contract - let contract = Contract::stock("AAPL").build(); - println!("Requesting historical data for {}\n", contract.symbol); + // Create contract based on asset type + let contract = match args.asset { + AssetType::Stock => Contract::stock("AAPL").build(), + AssetType::Forex => Contract::forex("EUR", "USD").build(), + AssetType::Futures => { + // For futures, we need to resolve the front-month contract via contract_details + println!("Resolving front-month contract for ES..."); + let query = Contract { + symbol: "ES".into(), + security_type: SecurityType::Future, + exchange: "CME".into(), + currency: "USD".into(), + ..Default::default() + }; - // Example 1: Get the earliest available data timestamp - println!("=== Head Timestamp ==="); - let head_timestamp = client - .head_timestamp(&contract, HistoricalWhatToShow::Trades, TradingHours::Regular) - .await?; - println!("Earliest available historical data: {head_timestamp:?}"); + let details = client.contract_details(&query).await?; + if details.is_empty() { + return Err("No futures contracts found for ES".into()); + } - // Example 2: Get recent intraday data (5-minute bars for last day) - println!("\n=== Recent Intraday Data (5-min bars) ==="); - let end_date = OffsetDateTime::now_utc(); - let historical_data = client - .historical_data( - &contract, - Some(end_date), - 1.days(), // Duration: 1 day - HistoricalBarSize::Min5, // 5-minute bars - Some(HistoricalWhatToShow::Trades), // Trade data - TradingHours::Regular, // Use regular trading hours - ) - .await?; + // Sort by contract month and take front-month + let mut sorted: Vec<_> = details + .into_iter() + .filter(|d| !d.contract.last_trade_date_or_contract_month.is_empty()) + .collect(); + sorted.sort_by(|a, b| { + a.contract + .last_trade_date_or_contract_month + .cmp(&b.contract.last_trade_date_or_contract_month) + }); - println!("Period: {} to {}", historical_data.start, historical_data.end); - println!("Total bars: {}", historical_data.bars.len()); - - // Show first 5 and last 5 bars - for (i, bar) in historical_data.bars.iter().take(5).enumerate() { - println!( - "Bar {}: {} - O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}, V: {:.0}", - i + 1, - format!("{:02}:{:02}", bar.date.hour(), bar.date.minute()), - bar.open, - bar.high, - bar.low, - bar.close, - bar.volume - ); - } - if historical_data.bars.len() > 10 { - println!("..."); - let start_idx = historical_data.bars.len() - 5; - for (i, bar) in historical_data.bars.iter().skip(start_idx).enumerate() { + let front = sorted.into_iter().next().expect("No valid contracts"); + println!( + " Found front-month: local_symbol='{}', contract_month='{}'", + front.contract.local_symbol, front.contract.last_trade_date_or_contract_month + ); + front.contract + } + }; + println!("Requesting historical data for {} ({:?})", contract.symbol, args.asset); + println!( + " local_symbol: '{}', exchange: {}, contract_month: '{}'\n", + contract.local_symbol, contract.exchange, contract.last_trade_date_or_contract_month + ); + + if !args.streaming_only { + // Example 1: Get the earliest available data timestamp + println!("=== Head Timestamp ==="); + let head_timestamp = client + .head_timestamp(&contract, HistoricalWhatToShow::Trades, TradingHours::Regular) + .await?; + println!("Earliest available historical data: {head_timestamp:?}"); + + // Example 2: Get recent intraday data (5-minute bars for last day) + println!("\n=== Recent Intraday Data (5-min bars) ==="); + let end_date = OffsetDateTime::now_utc(); + let historical_data = client + .historical_data( + &contract, + Some(end_date), + 1.days(), // Duration: 1 day + HistoricalBarSize::Min5, // 5-minute bars + Some(HistoricalWhatToShow::Trades), // Trade data + TradingHours::Regular, // Use regular trading hours + ) + .await?; + + println!("Period: {} to {}", historical_data.start, historical_data.end); + println!("Total bars: {}", historical_data.bars.len()); + + // Show first 5 and last 5 bars + for (i, bar) in historical_data.bars.iter().take(5).enumerate() { println!( "Bar {}: {} - O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}, V: {:.0}", - start_idx + i + 1, + i + 1, format!("{:02}:{:02}", bar.date.hour(), bar.date.minute()), bar.open, bar.high, @@ -87,90 +140,161 @@ async fn main() -> Result<(), Box> { bar.volume ); } - } + if historical_data.bars.len() > 10 { + println!("..."); + let start_idx = historical_data.bars.len() - 5; + for (i, bar) in historical_data.bars.iter().skip(start_idx).enumerate() { + println!( + "Bar {}: {} - O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}, V: {:.0}", + start_idx + i + 1, + format!("{:02}:{:02}", bar.date.hour(), bar.date.minute()), + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume + ); + } + } - // Example 3: Get daily data for past month - println!("\n=== Daily Data (past month) ==="); - let daily_data = client - .historical_data( - &contract, - Some(end_date), - 1.months(), // Duration: 1 month - HistoricalBarSize::Day, // Daily bars - Some(HistoricalWhatToShow::Trades), // Trade data - TradingHours::Regular, // Use regular trading hours - ) - .await?; + // Example 3: Get daily data for past month + println!("\n=== Daily Data (past month) ==="); + let daily_data = client + .historical_data( + &contract, + Some(end_date), + 1.months(), // Duration: 1 month + HistoricalBarSize::Day, // Daily bars + Some(HistoricalWhatToShow::Trades), // Trade data + TradingHours::Regular, // Use regular trading hours + ) + .await?; - println!("Daily bars received: {}", daily_data.bars.len()); - for bar in daily_data.bars.iter().take(5) { - println!( - "{}: O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}, V: {:.0}K", - format!("{:04}-{:02}-{:02}", bar.date.year(), bar.date.month() as u8, bar.date.day()), - bar.open, - bar.high, - bar.low, - bar.close, - bar.volume / 1000.0 - ); - } + println!("Daily bars received: {}", daily_data.bars.len()); + for bar in daily_data.bars.iter().take(5) { + println!( + "{}: O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}, V: {:.0}K", + format!("{:04}-{:02}-{:02}", bar.date.year(), bar.date.month() as u8, bar.date.day()), + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume / 1000.0 + ); + } - // Example 4: Get different data types - println!("\n=== Different Data Types ==="); + // Example 4: Get different data types + println!("\n=== Different Data Types ==="); - // Bid data (last 1 day) - let bid_data = client - .historical_data( - &contract, - Some(end_date), - 1.days(), - HistoricalBarSize::Min, - Some(HistoricalWhatToShow::Bid), - TradingHours::Regular, - ) - .await?; - println!("Bid bars (1-min): {} bars", bid_data.bars.len()); - if let Some(bar) = bid_data.bars.first() { - println!( - " First bar: {:02}:{:02}:{:02} - Bid: ${:.2}", - bar.date.hour(), - bar.date.minute(), - bar.date.second(), - bar.close - ); + // Bid data (last 1 day) + let bid_data = client + .historical_data( + &contract, + Some(end_date), + 1.days(), + HistoricalBarSize::Min, + Some(HistoricalWhatToShow::Bid), + TradingHours::Regular, + ) + .await?; + println!("Bid bars (1-min): {} bars", bid_data.bars.len()); + if let Some(bar) = bid_data.bars.first() { + println!( + " First bar: {:02}:{:02}:{:02} - Bid: ${:.2}", + bar.date.hour(), + bar.date.minute(), + bar.date.second(), + bar.close + ); + } + + // Ask data (last 1 day) + let ask_data = client + .historical_data( + &contract, + Some(end_date), + 1.days(), + HistoricalBarSize::Min, + Some(HistoricalWhatToShow::Ask), + TradingHours::Regular, + ) + .await?; + println!("Ask bars (1-min): {} bars", ask_data.bars.len()); + if let Some(bar) = ask_data.bars.first() { + println!( + " First bar: {:02}:{:02}:{:02} - Ask: ${:.2}", + bar.date.hour(), + bar.date.minute(), + bar.date.second(), + bar.close + ); + } + + // Example 5: Get histogram data + println!("\n=== Histogram Data ==="); + let histogram = client.histogram_data(&contract, TradingHours::Regular, HistoricalBarSize::Day).await?; + + println!("Histogram entries: {}", histogram.len()); + for entry in histogram.iter().take(5) { + println!(" Price: ${:.2}, Size: {}", entry.price, entry.size); + } } - // Ask data (last 1 day) - let ask_data = client - .historical_data( + // Example 6: Streaming historical data with keepUpToDate=true + println!("\n=== Streaming Historical Data (keepUpToDate=true) ==="); + println!("Press Ctrl+C to stop streaming...\n"); + + // Use appropriate data type per asset + let what_to_show = match args.asset { + AssetType::Forex => HistoricalWhatToShow::MidPoint, + _ => HistoricalWhatToShow::Trades, + }; + + let mut subscription = client + .historical_data_streaming( &contract, - Some(end_date), - 1.days(), - HistoricalBarSize::Min, - Some(HistoricalWhatToShow::Ask), - TradingHours::Regular, + 1.days(), // Duration: 1 day of history + HistoricalBarSize::Min, // 1-minute bars + Some(what_to_show), + TradingHours::Extended, + true, // keep_up_to_date: stream live updates ) .await?; - println!("Ask bars (1-min): {} bars", ask_data.bars.len()); - if let Some(bar) = ask_data.bars.first() { - println!( - " First bar: {:02}:{:02}:{:02} - Ask: ${:.2}", - bar.date.hour(), - bar.date.minute(), - bar.date.second(), - bar.close - ); - } - // Example 5: Get histogram data - println!("\n=== Histogram Data ==="); - let histogram = client.histogram_data(&contract, TradingHours::Regular, HistoricalBarSize::Day).await?; - - println!("Histogram entries: {}", histogram.len()); - for entry in histogram.iter().take(5) { - println!(" Price: ${:.2}, Size: {}", entry.price, entry.size); + while let Some(update) = subscription.next().await { + match update { + HistoricalBarUpdate::Historical(data) => { + println!("Received {} initial historical bars", data.bars.len()); + if let Some(bar) = data.bars.last() { + println!( + " Latest: {} - O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}", + format!("{:02}:{:02}", bar.date.hour(), bar.date.minute()), + bar.open, + bar.high, + bar.low, + bar.close + ); + } + } + HistoricalBarUpdate::HistoricalEnd => { + println!("Initial historical data complete. Now streaming updates..."); + } + HistoricalBarUpdate::Update(bar) => { + println!( + "UPDATE: {} - O: ${:.2}, H: ${:.2}, L: ${:.2}, C: ${:.2}, V: {:.0}", + format!("{:02}:{:02}:{:02}", bar.date.hour(), bar.date.minute(), bar.date.second()), + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume + ); + } + } } + println!("Stream ended"); + println!("\nHistorical data example completed!"); Ok(()) } diff --git a/src/client/async.rs b/src/client/async.rs index a634283e..13c081ba 100644 --- a/src/client/async.rs +++ b/src/client/async.rs @@ -1039,6 +1039,61 @@ impl Client { crate::market_data::historical::historical_data(self, contract, end_date, duration, bar_size, what_to_show, trading_hours).await } + /// Requests historical data with optional streaming updates. + /// + /// This method returns a subscription that first yields the initial historical bars. + /// When `keep_up_to_date` is `true`, it continues to yield streaming updates for + /// the current bar as it builds. IBKR sends updated bars every ~4-6 seconds until + /// the bar completes. + /// + /// # Arguments + /// * `contract` - Contract object that is subject of query + /// * `duration` - The amount of time for which the data needs to be retrieved + /// * `bar_size` - The bar size (resolution) + /// * `what_to_show` - The type of data to retrieve (Trades, MidPoint, etc.) + /// * `trading_hours` - Regular trading hours only, or include extended hours + /// * `keep_up_to_date` - If true, continue receiving streaming updates after initial data + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::contracts::Contract; + /// use ibapi::Client; + /// use ibapi::market_data::historical::{ToDuration, BarSize, WhatToShow, HistoricalBarUpdate}; + /// use ibapi::market_data::TradingHours; + /// + /// #[tokio::main] + /// async fn main() { + /// let client = Client::connect("127.0.0.1:4002", 100).await.expect("connection failed"); + /// let contract = Contract::stock("SPY").build(); + /// + /// let mut subscription = client + /// .historical_data_streaming(&contract, 3.days(), BarSize::Min15, Some(WhatToShow::Trades), TradingHours::Extended, true) + /// .await + /// .expect("streaming request failed"); + /// + /// while let Some(update) = subscription.next().await { + /// match update { + /// HistoricalBarUpdate::Historical(data) => println!("Initial bars: {}", data.bars.len()), + /// HistoricalBarUpdate::Update(bar) => println!("Streaming update: {:?}", bar), + /// HistoricalBarUpdate::HistoricalEnd => println!("Initial data complete"), + /// } + /// } + /// } + /// ``` + pub async fn historical_data_streaming( + &self, + contract: &crate::contracts::Contract, + duration: crate::market_data::historical::Duration, + bar_size: crate::market_data::historical::BarSize, + what_to_show: Option, + trading_hours: TradingHours, + keep_up_to_date: bool, + ) -> Result { + crate::market_data::historical::historical_data_streaming(self, contract, duration, bar_size, what_to_show, trading_hours, keep_up_to_date) + .await + } + /// Requests historical schedule. /// /// # Arguments diff --git a/src/lib.rs b/src/lib.rs index de135a52..66c324e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,7 +60,7 @@ pub(crate) mod connection; /// /// # Example /// -/// ```no_run +/// ```ignore /// use ibapi::{Client, StartupMessageCallback}; /// use ibapi::messages::IncomingMessages; /// use std::sync::{Arc, Mutex}; diff --git a/src/market_data/historical/async.rs b/src/market_data/historical/async.rs index b4977321..3e4e8e02 100644 --- a/src/market_data/historical/async.rs +++ b/src/market_data/historical/async.rs @@ -1,6 +1,7 @@ use log::{debug, warn}; use std::collections::VecDeque; use time::OffsetDateTime; +use time_tz::Tz; use crate::client::ClientRequestBuilders; use crate::contracts::Contract; @@ -10,7 +11,9 @@ use crate::transport::AsyncInternalSubscription; use crate::{Client, Error, MAX_RETRIES}; use super::common::{decoders, encoders}; -use super::{BarSize, Duration, HistogramEntry, HistoricalData, Schedule, TickBidAsk, TickDecoder, TickLast, TickMidpoint, WhatToShow}; +use super::{ + BarSize, Duration, HistogramEntry, HistoricalBarUpdate, HistoricalData, Schedule, TickBidAsk, TickDecoder, TickLast, TickMidpoint, WhatToShow, +}; use crate::market_data::TradingHours; // === Public API Functions === @@ -323,6 +326,200 @@ impl + Send> TickSubscription { } } +// === Historical Data Streaming with keepUpToDate === + +/// Requests historical data for a contract with optional streaming updates. +/// +/// When `keep_up_to_date` is `true`, this function requests historical bars and then +/// continues to receive streaming updates for the current (incomplete) bar. IBKR sends +/// updates approximately every 4-6 seconds until the bar completes, at which point a +/// new bar begins. +/// +/// When `keep_up_to_date` is `false`, only the initial historical data is returned +/// and the subscription ends after delivering the data. +/// +/// **Important IBKR behavior (when keepUpToDate=true):** +/// - The same timestamp bar is sent multiple times as it builds (with updated OHLCV) +/// - When a NEW timestamp appears, the previous bar is considered complete +/// - Supported `what_to_show` values: Trades, Midpoint, Bid, Ask only +/// +/// # Arguments +/// * `client` - The IBKR client connection +/// * `contract` - The contract to request data for +/// * `duration` - How far back to request initial historical data +/// * `bar_size` - The bar size (e.g., Min15 for 15-minute bars) +/// * `what_to_show` - The data type (Trades, Midpoint, Bid, or Ask) +/// * `trading_hours` - Whether to use regular trading hours only +/// * `keep_up_to_date` - If true, continue receiving streaming updates after initial data +/// +/// # Returns +/// A `HistoricalDataStreamingSubscription` that yields `HistoricalBarUpdate` values +/// +/// # Example +/// ```ignore +/// let mut subscription = historical_data_streaming( +/// &client, +/// &contract, +/// Duration::days(1), +/// BarSize::Min15, +/// Some(WhatToShow::Trades), +/// TradingHours::Regular, +/// true, // keep_up_to_date +/// ).await?; +/// +/// while let Some(update) = subscription.next().await { +/// match update { +/// HistoricalBarUpdate::Historical(data) => { +/// println!("Received {} historical bars", data.bars.len()); +/// } +/// HistoricalBarUpdate::Update(bar) => { +/// println!("Bar update: {} close={}", bar.date, bar.close); +/// } +/// HistoricalBarUpdate::HistoricalEnd => { +/// println!("Initial historical data complete, now streaming"); +/// } +/// } +/// } +/// ``` +/// +/// # See Also +/// * [IBKR Campus - keepUpToDate](https://ibkrcampus.com/campus/ibkr-api-page/twsapi-doc/#hist-keepUp-date) +pub async fn historical_data_streaming( + client: &Client, + contract: &Contract, + duration: Duration, + bar_size: BarSize, + what_to_show: Option, + trading_hours: TradingHours, + keep_up_to_date: bool, +) -> Result { + if !contract.trading_class.is_empty() || contract.contract_id > 0 { + check_version(client.server_version(), Features::TRADING_CLASS)?; + } + + // Note: end_date must be None when keepUpToDate=true (IBKR requirement) + let builder = client.request(); + let request = encoders::encode_request_historical_data( + client.server_version(), + builder.request_id(), + contract, + None, // end_date must be None for keepUpToDate + duration, + bar_size, + what_to_show, + trading_hours.use_rth(), + keep_up_to_date, + Vec::::default(), + )?; + + let subscription = builder.send_raw(request).await?; + + // Get the timezone directly to avoid lifetime issues + // time_zone(client) returns a reference tied to client's lifetime, + // but we need a 'static reference for the subscription struct + let tz: &'static Tz = client.time_zone.unwrap_or_else(|| { + warn!("server timezone unknown. assuming UTC, but that may be incorrect!"); + time_tz::timezones::db::UTC + }); + + Ok(HistoricalDataStreamingSubscription::new(subscription, client.server_version(), tz)) +} + +/// Async subscription for streaming historical data with keepUpToDate=true. +/// +/// This subscription first yields the initial historical bars, then continues +/// to yield streaming updates for the current bar as it builds. +pub struct HistoricalDataStreamingSubscription { + messages: AsyncInternalSubscription, + server_version: i32, + time_zone: &'static Tz, + pending_end: bool, + error: Option, +} + +impl HistoricalDataStreamingSubscription { + fn new(messages: AsyncInternalSubscription, server_version: i32, time_zone: &'static Tz) -> Self { + Self { + messages, + server_version, + time_zone, + pending_end: false, + error: None, + } + } + + /// Get the next update from the streaming subscription. + /// + /// Returns: + /// - `Some(HistoricalBarUpdate::Historical(data))` - Initial batch of historical bars + /// - `Some(HistoricalBarUpdate::HistoricalEnd)` - End of initial historical data + /// - `Some(HistoricalBarUpdate::Update(bar))` - Streaming bar update + /// - `None` - Subscription ended (connection closed or error) + pub async fn next(&mut self) -> Option { + // Emit HistoricalEnd after Historical data was returned + if self.pending_end { + self.pending_end = false; + return Some(HistoricalBarUpdate::HistoricalEnd); + } + + loop { + match self.messages.next().await { + Some(Ok(mut message)) => { + match message.message_type() { + IncomingMessages::HistoricalData => { + // Initial historical data batch + match decoders::decode_historical_data(self.server_version, self.time_zone, &mut message) { + Ok(data) => { + self.pending_end = true; + return Some(HistoricalBarUpdate::Historical(data)); + } + Err(e) => { + self.error = Some(e); + return None; + } + } + } + IncomingMessages::HistoricalDataUpdate => { + // Streaming bar update + match decoders::decode_historical_data_update(self.time_zone, &mut message) { + Ok(bar) => { + return Some(HistoricalBarUpdate::Update(bar)); + } + Err(e) => { + self.error = Some(e); + return None; + } + } + } + IncomingMessages::Error => { + self.error = Some(Error::from(message)); + return None; + } + _ => { + // Skip unexpected messages + debug!("unexpected message in streaming subscription: {:?}", message.message_type()); + continue; + } + } + } + Some(Err(e)) => { + self.error = Some(e); + return None; + } + None => { + // Channel closed + return None; + } + } + } + } + + /// Returns the last error that occurred, if any. + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } +} + #[cfg(test)] mod tests { use super::*; @@ -949,4 +1146,170 @@ mod tests { let tz = time_zone(&client); assert_eq!(tz, time_tz::timezones::db::UTC, "Should fallback to UTC when timezone not set"); } + + #[tokio::test] + async fn test_historical_data_streaming_with_updates() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + // Initial historical data (message type 17) + "17|9000|20230315 09:30:00|20230315 10:30:00|1|1678886400|185.50|186.00|185.25|185.75|1000|185.70|100|".to_owned(), + // Streaming update (message type 90) + "90|9000|-1|1678890000|185.80|186.10|185.60|185.90|500|185.85|50|".to_owned(), + ], + }); + + let mut client = Client::stubbed(message_bus.clone(), server_versions::SIZE_RULES); + client.time_zone = Some(time_tz::timezones::db::UTC); + + let contract = Contract::stock("SPY").build(); + + let mut subscription = historical_data_streaming( + &client, + &contract, + Duration::days(1), + BarSize::Hour, + Some(WhatToShow::Trades), + TradingHours::Regular, + true, + ) + .await + .expect("streaming request should succeed"); + + // First: receive initial historical data + let update1 = subscription.next().await; + assert!(update1.is_some(), "Should receive initial historical data"); + match update1.unwrap() { + HistoricalBarUpdate::Historical(data) => { + assert_eq!(data.bars.len(), 1, "Should have 1 initial bar"); + assert_eq!(data.bars[0].open, 185.50, "Wrong open price"); + } + _ => panic!("Expected Historical variant"), + } + + // Second: receive HistoricalEnd marker + let update2 = subscription.next().await; + assert!(update2.is_some(), "Should receive HistoricalEnd"); + match update2.unwrap() { + HistoricalBarUpdate::HistoricalEnd => {} + _ => panic!("Expected HistoricalEnd variant"), + } + + // Third: receive streaming update + let update3 = subscription.next().await; + assert!(update3.is_some(), "Should receive streaming update"); + match update3.unwrap() { + HistoricalBarUpdate::Update(bar) => { + assert_eq!(bar.open, 185.80, "Wrong open price in update"); + assert_eq!(bar.high, 186.10, "Wrong high price in update"); + assert_eq!(bar.close, 185.90, "Wrong close price in update"); + } + _ => panic!("Expected Update variant"), + } + + // Verify request message includes keepUpToDate=true + let request_messages = message_bus.request_messages.read().unwrap(); + assert_eq!(request_messages.len(), 1, "Should send one request"); + // The keepUpToDate field should be "1" (true) + assert!( + request_messages[0].fields.contains(&"1".to_string()), + "Request should have keepUpToDate=true" + ); + } + + #[tokio::test] + async fn test_historical_data_streaming_keep_up_to_date_false() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + // Initial historical data only + "17|9000|20230315 09:30:00|20230315 10:30:00|1|1678886400|185.50|186.00|185.25|185.75|1000|185.70|100|".to_owned(), + ], + }); + + let mut client = Client::stubbed(message_bus.clone(), server_versions::SIZE_RULES); + client.time_zone = Some(time_tz::timezones::db::UTC); + + let contract = Contract::stock("SPY").build(); + + let mut subscription = historical_data_streaming( + &client, + &contract, + Duration::days(1), + BarSize::Hour, + Some(WhatToShow::Trades), + TradingHours::Regular, + false, // keep_up_to_date = false + ) + .await + .expect("streaming request should succeed"); + + // Receive initial historical data + let update1 = subscription.next().await; + assert!(update1.is_some(), "Should receive initial historical data"); + match update1.unwrap() { + HistoricalBarUpdate::Historical(data) => { + assert_eq!(data.bars.len(), 1, "Should have 1 initial bar"); + } + _ => panic!("Expected Historical variant"), + } + + // Receive HistoricalEnd marker + let update2 = subscription.next().await; + assert!(update2.is_some(), "Should receive HistoricalEnd"); + match update2.unwrap() { + HistoricalBarUpdate::HistoricalEnd => {} + _ => panic!("Expected HistoricalEnd variant"), + } + + // Verify request message includes keepUpToDate=false + let request_messages = message_bus.request_messages.read().unwrap(); + assert_eq!(request_messages.len(), 1, "Should send one request"); + // Find the keepUpToDate field - it should be "0" (false) + // The field order in historical data request puts keepUpToDate near the end + let request = &request_messages[0]; + // Check the last few fields for the "0" value + let fields_str = request.fields.join("|"); + assert!(fields_str.contains("|0|"), "Request should have keepUpToDate=false (0)"); + } + + #[tokio::test] + async fn test_historical_data_streaming_error_response() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + // Error response + "4|2|9000|162|Historical Market Data Service error message:No market data permissions.|".to_owned(), + ], + }); + + let mut client = Client::stubbed(message_bus, server_versions::SIZE_RULES); + client.time_zone = Some(time_tz::timezones::db::UTC); + + let contract = Contract::stock("SPY").build(); + + let mut subscription = historical_data_streaming( + &client, + &contract, + Duration::days(1), + BarSize::Hour, + Some(WhatToShow::Trades), + TradingHours::Regular, + true, + ) + .await + .expect("streaming request should succeed"); + + // Should return None due to error + let update = subscription.next().await; + assert!(update.is_none(), "Should return None on error"); + + // Error should be accessible + let error = subscription.error(); + assert!(error.is_some(), "Error should be stored"); + assert!( + error.unwrap().to_string().contains("No market data permissions"), + "Error should contain the message" + ); + } } diff --git a/src/market_data/historical/common/decoders.rs b/src/market_data/historical/common/decoders.rs index 9c319ff1..c4397c56 100644 --- a/src/market_data/historical/common/decoders.rs +++ b/src/market_data/historical/common/decoders.rs @@ -221,6 +221,51 @@ pub(crate) fn decode_histogram_data(message: &mut ResponseMessage) -> Result Result { + message.skip(); // message type + message.skip(); // request_id + message.skip(); // bar_count (always -1 for updates) + + let date = message.next_string()?; + let open = message.next_double()?; + let high = message.next_double()?; + let low = message.next_double()?; + let close = message.next_double()?; + let volume = message.next_double()?; + let wap = message.next_double()?; + // count field is optional in streaming updates - may not be present + let count = message.next_int().unwrap_or(0); + + Ok(Bar { + date: parse_bar_date(&date, time_zone)?, + open, + high, + low, + close, + volume, + wap, + count, + }) +} + fn parse_time_zone(name: &str) -> &Tz { let zones = find_timezone(name); if zones.is_empty() { @@ -436,4 +481,44 @@ mod tests { assert_eq!(ticks[23].price, 91.31, "ticks[0].price"); assert_eq!(ticks[23].size, 0, "ticks[0].size"); } + + #[cfg(feature = "async")] + #[test] + fn test_decode_historical_data_update() { + let time_zone: &Tz = time_tz::timezones::db::america::NEW_YORK; + + // Message format: message_type|request_id|bar_count|timestamp|open|high|low|close|volume|wap|count + let mut message = ResponseMessage::from("90\09000\0-1\01681133400\0185.50\0186.00\0185.00\0185.75\01000.5\0185.625\0150\0"); + + let bar = decode_historical_data_update(time_zone, &mut message).expect("error decoding historical data update"); + + assert_eq!(bar.date, datetime!(2023-04-10 13:30:00 UTC), "bar.date"); + assert_eq!(bar.open, 185.50, "bar.open"); + assert_eq!(bar.high, 186.00, "bar.high"); + assert_eq!(bar.low, 185.00, "bar.low"); + assert_eq!(bar.close, 185.75, "bar.close"); + assert_eq!(bar.volume, 1000.5, "bar.volume"); + assert_eq!(bar.wap, 185.625, "bar.wap"); + assert_eq!(bar.count, 150, "bar.count"); + } + + #[cfg(feature = "async")] + #[test] + fn test_decode_historical_data_update_without_count() { + let time_zone: &Tz = time_tz::timezones::db::america::NEW_YORK; + + // Message without count field (optional in streaming updates) + let mut message = ResponseMessage::from("90\09000\0-1\01681133400\0185.50\0186.00\0185.00\0185.75\01000.5\0185.625\0"); + + let bar = decode_historical_data_update(time_zone, &mut message).expect("error decoding historical data update"); + + assert_eq!(bar.date, datetime!(2023-04-10 13:30:00 UTC), "bar.date"); + assert_eq!(bar.open, 185.50, "bar.open"); + assert_eq!(bar.high, 186.00, "bar.high"); + assert_eq!(bar.low, 185.00, "bar.low"); + assert_eq!(bar.close, 185.75, "bar.close"); + assert_eq!(bar.volume, 1000.5, "bar.volume"); + assert_eq!(bar.wap, 185.625, "bar.wap"); + assert_eq!(bar.count, 0, "bar.count should default to 0 when missing"); + } } diff --git a/src/market_data/historical/mod.rs b/src/market_data/historical/mod.rs index f61953e5..7764ce06 100644 --- a/src/market_data/historical/mod.rs +++ b/src/market_data/historical/mod.rs @@ -326,6 +326,23 @@ pub struct HistoricalData { pub bars: Vec, } +/// Update from historical data streaming with keepUpToDate=true. +/// +/// When requesting historical data with `keepUpToDate=true`, IBKR first sends +/// the historical bars, then continues streaming updates for the current bar. +/// The current bar is updated approximately every 4-6 seconds until a new +/// bar begins. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum HistoricalBarUpdate { + /// Initial batch of historical bars. + Historical(HistoricalData), + /// Real-time update of the current (incomplete) bar. + /// Note: Multiple updates with the same timestamp will be sent as the bar builds. + Update(Bar), + /// Signals the end of the initial historical data batch. + HistoricalEnd, +} + /// Trading schedule describing sessions for a contract. #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct Schedule { @@ -563,7 +580,7 @@ impl TickDecoder for TickMidpoint { pub use sync::{TickSubscription, TickSubscriptionIter, TickSubscriptionOwnedIter, TickSubscriptionTimeoutIter, TickSubscriptionTryIter}; #[cfg(feature = "async")] -pub use r#async::TickSubscription; +pub use r#async::{historical_data_streaming, HistoricalDataStreamingSubscription, TickSubscription}; #[cfg(test)] mod tests { diff --git a/src/messages.rs b/src/messages.rs index 0d460407..132ab981 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -410,6 +410,7 @@ pub fn request_id_index(kind: IncomingMessages) -> Option { IncomingMessages::HeadTimestamp => Some(1), IncomingMessages::HistogramData => Some(1), IncomingMessages::HistoricalData => Some(1), + IncomingMessages::HistoricalDataUpdate => Some(1), IncomingMessages::HistoricalNews => Some(1), IncomingMessages::HistoricalNewsEnd => Some(1), IncomingMessages::HistoricalSchedule => Some(1), diff --git a/src/orders/builder/order_builder.rs b/src/orders/builder/order_builder.rs index 072262fc..c57804bf 100644 --- a/src/orders/builder/order_builder.rs +++ b/src/orders/builder/order_builder.rs @@ -529,7 +529,7 @@ impl<'a, C> OrderBuilder<'a, C> { /// /// # Example with builder /// - /// ```no_run + /// ```ignore /// # async fn example() -> Result<(), Box> { /// # use ibapi::client::Client; /// # use ibapi::contracts::Contract; @@ -552,7 +552,7 @@ impl<'a, C> OrderBuilder<'a, C> { /// /// # Example with string (for custom strategies) /// - /// ```no_run + /// ```ignore /// # async fn example() -> Result<(), Box> { /// # use ibapi::client::Client; /// # use ibapi::contracts::Contract; From d727fa38157553247e7a3cd638c4358c58762486 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Fri, 30 Jan 2026 21:54:57 -0800 Subject: [PATCH 15/18] Add sync historical_data_streaming(), address PR #384 feedback (#385) * Add sync historical_data_streaming(), address PR #384 feedback - Add blocking HistoricalDataStreamingSubscription with next/try_next/next_timeout - Add Client::historical_data_streaming() for sync API - Remove HistoricalEnd variant from HistoricalBarUpdate enum - Fix doc example imports to use prelude types - Replace fragile contains() test assertions with field index checks - Remove unnecessary #[cfg(feature = "async")] from decoder * Export historical_data_streaming doctest (ignore -> no_run) * Bump version to 2.7.0 --- Cargo.toml | 2 +- examples/async/historical_data.rs | 4 +- src/client/async.rs | 10 +- src/client/sync.rs | 52 +++ src/market_data/historical/async.rs | 69 ++-- src/market_data/historical/common/decoders.rs | 3 - src/market_data/historical/mod.rs | 14 +- src/market_data/historical/sync.rs | 312 +++++++++++++++++- 8 files changed, 402 insertions(+), 64 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0e6aea58..0172ac59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ibapi" -version = "2.6.2" +version = "2.7.0" edition = "2021" authors = ["Wil Boayue "] description = "A Rust implementation of the Interactive Brokers TWS API, providing a reliable and user friendly interface for TWS and IB Gateway. Designed with a focus on simplicity and performance." diff --git a/examples/async/historical_data.rs b/examples/async/historical_data.rs index a9db50f4..2e3fb6db 100644 --- a/examples/async/historical_data.rs +++ b/examples/async/historical_data.rs @@ -275,9 +275,7 @@ async fn main() -> Result<(), Box> { bar.close ); } - } - HistoricalBarUpdate::HistoricalEnd => { - println!("Initial historical data complete. Now streaming updates..."); + println!("Now streaming updates..."); } HistoricalBarUpdate::Update(bar) => { println!( diff --git a/src/client/async.rs b/src/client/async.rs index 13c081ba..0189dcbb 100644 --- a/src/client/async.rs +++ b/src/client/async.rs @@ -1059,8 +1059,8 @@ impl Client { /// ```no_run /// use ibapi::contracts::Contract; /// use ibapi::Client; - /// use ibapi::market_data::historical::{ToDuration, BarSize, WhatToShow, HistoricalBarUpdate}; - /// use ibapi::market_data::TradingHours; + /// use ibapi::market_data::historical::{ToDuration, HistoricalBarUpdate}; + /// use ibapi::prelude::{HistoricalBarSize, HistoricalWhatToShow, TradingHours}; /// /// #[tokio::main] /// async fn main() { @@ -1068,7 +1068,10 @@ impl Client { /// let contract = Contract::stock("SPY").build(); /// /// let mut subscription = client - /// .historical_data_streaming(&contract, 3.days(), BarSize::Min15, Some(WhatToShow::Trades), TradingHours::Extended, true) + /// .historical_data_streaming( + /// &contract, 3.days(), HistoricalBarSize::Min15, + /// Some(HistoricalWhatToShow::Trades), TradingHours::Extended, true + /// ) /// .await /// .expect("streaming request failed"); /// @@ -1076,7 +1079,6 @@ impl Client { /// match update { /// HistoricalBarUpdate::Historical(data) => println!("Initial bars: {}", data.bars.len()), /// HistoricalBarUpdate::Update(bar) => println!("Streaming update: {:?}", bar), - /// HistoricalBarUpdate::HistoricalEnd => println!("Initial data complete"), /// } /// } /// } diff --git a/src/client/sync.rs b/src/client/sync.rs index 728690f8..bc97b7db 100644 --- a/src/client/sync.rs +++ b/src/client/sync.rs @@ -1114,6 +1114,58 @@ impl Client { historical::blocking::historical_data(self, contract, interval_end, duration, bar_size, Some(what_to_show), trading_hours) } + /// Requests historical data with optional streaming updates. + /// + /// This method returns a subscription that first yields the initial historical bars. + /// When `keep_up_to_date` is `true`, it continues to yield streaming updates for + /// the current bar as it builds. IBKR sends updated bars every ~4-6 seconds until + /// the bar completes. + /// + /// # Arguments + /// * `contract` - Contract object that is subject of query + /// * `duration` - The amount of time for which the data needs to be retrieved + /// * `bar_size` - The bar size (resolution) + /// * `what_to_show` - The type of data to retrieve (Trades, MidPoint, etc.) + /// * `trading_hours` - Regular trading hours only, or include extended hours + /// * `keep_up_to_date` - If true, continue receiving streaming updates after initial data + /// + /// # Examples + /// + /// ```no_run + /// use ibapi::contracts::Contract; + /// use ibapi::client::blocking::Client; + /// use ibapi::market_data::historical::{ToDuration, HistoricalBarUpdate}; + /// use ibapi::prelude::{HistoricalBarSize, HistoricalWhatToShow, TradingHours}; + /// + /// let client = Client::connect("127.0.0.1:4002", 100).expect("connection failed"); + /// let contract = Contract::stock("SPY").build(); + /// + /// let subscription = client + /// .historical_data_streaming( + /// &contract, 3.days(), HistoricalBarSize::Min15, + /// Some(HistoricalWhatToShow::Trades), TradingHours::Extended, true + /// ) + /// .expect("streaming request failed"); + /// + /// while let Some(update) = subscription.next() { + /// match update { + /// HistoricalBarUpdate::Historical(data) => println!("Initial bars: {}", data.bars.len()), + /// HistoricalBarUpdate::Update(bar) => println!("Streaming update: {:?}", bar), + /// } + /// } + /// ``` + pub fn historical_data_streaming( + &self, + contract: &Contract, + duration: historical::Duration, + bar_size: historical::BarSize, + what_to_show: Option, + trading_hours: TradingHours, + keep_up_to_date: bool, + ) -> Result { + historical::blocking::historical_data_streaming(self, contract, duration, bar_size, what_to_show, trading_hours, keep_up_to_date) + } + /// Requests [Schedule](historical::Schedule) for an interval of given duration /// ending at specified date. /// diff --git a/src/market_data/historical/async.rs b/src/market_data/historical/async.rs index 3e4e8e02..d082481f 100644 --- a/src/market_data/historical/async.rs +++ b/src/market_data/historical/async.rs @@ -356,7 +356,18 @@ impl + Send> TickSubscription { /// A `HistoricalDataStreamingSubscription` that yields `HistoricalBarUpdate` values /// /// # Example -/// ```ignore +/// ```no_run +/// use ibapi::Client; +/// use ibapi::contracts::Contract; +/// use ibapi::market_data::historical::{ +/// BarSize, Duration, HistoricalBarUpdate, WhatToShow, historical_data_streaming +/// }; +/// use ibapi::market_data::TradingHours; +/// +/// # async fn example() -> Result<(), ibapi::Error> { +/// let client = Client::connect("127.0.0.1:4002", 100).await?; +/// let contract = Contract::stock("SPY").build(); +/// /// let mut subscription = historical_data_streaming( /// &client, /// &contract, @@ -375,11 +386,10 @@ impl + Send> TickSubscription { /// HistoricalBarUpdate::Update(bar) => { /// println!("Bar update: {} close={}", bar.date, bar.close); /// } -/// HistoricalBarUpdate::HistoricalEnd => { -/// println!("Initial historical data complete, now streaming"); -/// } /// } /// } +/// # Ok(()) +/// # } /// ``` /// /// # See Also @@ -427,13 +437,12 @@ pub async fn historical_data_streaming( /// Async subscription for streaming historical data with keepUpToDate=true. /// -/// This subscription first yields the initial historical bars, then continues -/// to yield streaming updates for the current bar as it builds. +/// This subscription first yields the initial historical bars as a `Historical` variant, +/// then continues to yield streaming updates for the current bar as `Update` variants. pub struct HistoricalDataStreamingSubscription { messages: AsyncInternalSubscription, server_version: i32, time_zone: &'static Tz, - pending_end: bool, error: Option, } @@ -443,7 +452,6 @@ impl HistoricalDataStreamingSubscription { messages, server_version, time_zone, - pending_end: false, error: None, } } @@ -451,17 +459,10 @@ impl HistoricalDataStreamingSubscription { /// Get the next update from the streaming subscription. /// /// Returns: - /// - `Some(HistoricalBarUpdate::Historical(data))` - Initial batch of historical bars - /// - `Some(HistoricalBarUpdate::HistoricalEnd)` - End of initial historical data + /// - `Some(HistoricalBarUpdate::Historical(data))` - Initial batch of historical bars (always first) /// - `Some(HistoricalBarUpdate::Update(bar))` - Streaming bar update /// - `None` - Subscription ended (connection closed or error) pub async fn next(&mut self) -> Option { - // Emit HistoricalEnd after Historical data was returned - if self.pending_end { - self.pending_end = false; - return Some(HistoricalBarUpdate::HistoricalEnd); - } - loop { match self.messages.next().await { Some(Ok(mut message)) => { @@ -470,7 +471,6 @@ impl HistoricalDataStreamingSubscription { // Initial historical data batch match decoders::decode_historical_data(self.server_version, self.time_zone, &mut message) { Ok(data) => { - self.pending_end = true; return Some(HistoricalBarUpdate::Historical(data)); } Err(e) => { @@ -1187,18 +1187,10 @@ mod tests { _ => panic!("Expected Historical variant"), } - // Second: receive HistoricalEnd marker + // Second: receive streaming update let update2 = subscription.next().await; - assert!(update2.is_some(), "Should receive HistoricalEnd"); + assert!(update2.is_some(), "Should receive streaming update"); match update2.unwrap() { - HistoricalBarUpdate::HistoricalEnd => {} - _ => panic!("Expected HistoricalEnd variant"), - } - - // Third: receive streaming update - let update3 = subscription.next().await; - assert!(update3.is_some(), "Should receive streaming update"); - match update3.unwrap() { HistoricalBarUpdate::Update(bar) => { assert_eq!(bar.open, 185.80, "Wrong open price in update"); assert_eq!(bar.high, 186.10, "Wrong high price in update"); @@ -1210,11 +1202,8 @@ mod tests { // Verify request message includes keepUpToDate=true let request_messages = message_bus.request_messages.read().unwrap(); assert_eq!(request_messages.len(), 1, "Should send one request"); - // The keepUpToDate field should be "1" (true) - assert!( - request_messages[0].fields.contains(&"1".to_string()), - "Request should have keepUpToDate=true" - ); + // keepUpToDate is at field index 21 (for non-bag contracts) + assert_eq!(request_messages[0].fields[21], "1", "Request should have keepUpToDate=true at field[21]"); } #[tokio::test] @@ -1254,23 +1243,11 @@ mod tests { _ => panic!("Expected Historical variant"), } - // Receive HistoricalEnd marker - let update2 = subscription.next().await; - assert!(update2.is_some(), "Should receive HistoricalEnd"); - match update2.unwrap() { - HistoricalBarUpdate::HistoricalEnd => {} - _ => panic!("Expected HistoricalEnd variant"), - } - // Verify request message includes keepUpToDate=false let request_messages = message_bus.request_messages.read().unwrap(); assert_eq!(request_messages.len(), 1, "Should send one request"); - // Find the keepUpToDate field - it should be "0" (false) - // The field order in historical data request puts keepUpToDate near the end - let request = &request_messages[0]; - // Check the last few fields for the "0" value - let fields_str = request.fields.join("|"); - assert!(fields_str.contains("|0|"), "Request should have keepUpToDate=false (0)"); + // keepUpToDate is at field index 21 (for non-bag contracts) + assert_eq!(request_messages[0].fields[21], "0", "Request should have keepUpToDate=false at field[21]"); } #[tokio::test] diff --git a/src/market_data/historical/common/decoders.rs b/src/market_data/historical/common/decoders.rs index c4397c56..db36985e 100644 --- a/src/market_data/historical/common/decoders.rs +++ b/src/market_data/historical/common/decoders.rs @@ -238,7 +238,6 @@ pub(crate) fn decode_histogram_data(message: &mut ResponseMessage) -> Result Result { message.skip(); // message type message.skip(); // request_id @@ -482,7 +481,6 @@ mod tests { assert_eq!(ticks[23].size, 0, "ticks[0].size"); } - #[cfg(feature = "async")] #[test] fn test_decode_historical_data_update() { let time_zone: &Tz = time_tz::timezones::db::america::NEW_YORK; @@ -502,7 +500,6 @@ mod tests { assert_eq!(bar.count, 150, "bar.count"); } - #[cfg(feature = "async")] #[test] fn test_decode_historical_data_update_without_count() { let time_zone: &Tz = time_tz::timezones::db::america::NEW_YORK; diff --git a/src/market_data/historical/mod.rs b/src/market_data/historical/mod.rs index 7764ce06..2bcfb769 100644 --- a/src/market_data/historical/mod.rs +++ b/src/market_data/historical/mod.rs @@ -329,18 +329,17 @@ pub struct HistoricalData { /// Update from historical data streaming with keepUpToDate=true. /// /// When requesting historical data with `keepUpToDate=true`, IBKR first sends -/// the historical bars, then continues streaming updates for the current bar. +/// the initial historical bars as a `Historical` variant, then continues +/// streaming real-time updates for the current bar as `Update` variants. /// The current bar is updated approximately every 4-6 seconds until a new /// bar begins. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum HistoricalBarUpdate { - /// Initial batch of historical bars. + /// Initial batch of historical bars. Always received first. Historical(HistoricalData), /// Real-time update of the current (incomplete) bar. - /// Note: Multiple updates with the same timestamp will be sent as the bar builds. + /// Multiple updates with the same timestamp will be sent as the bar builds. Update(Bar), - /// Signals the end of the initial historical data batch. - HistoricalEnd, } /// Trading schedule describing sessions for a contract. @@ -577,7 +576,10 @@ impl TickDecoder for TickMidpoint { // Re-export TickSubscription and iterator types based on active feature #[cfg(all(feature = "sync", not(feature = "async")))] -pub use sync::{TickSubscription, TickSubscriptionIter, TickSubscriptionOwnedIter, TickSubscriptionTimeoutIter, TickSubscriptionTryIter}; +pub use sync::{ + HistoricalDataStreamingSubscription, TickSubscription, TickSubscriptionIter, TickSubscriptionOwnedIter, TickSubscriptionTimeoutIter, + TickSubscriptionTryIter, +}; #[cfg(feature = "async")] pub use r#async::{historical_data_streaming, HistoricalDataStreamingSubscription, TickSubscription}; diff --git a/src/market_data/historical/sync.rs b/src/market_data/historical/sync.rs index 3363c774..d1f0ffc0 100644 --- a/src/market_data/historical/sync.rs +++ b/src/market_data/historical/sync.rs @@ -12,8 +12,12 @@ use crate::protocol::{check_version, Features}; use crate::transport::{InternalSubscription, Response}; use crate::{client::sync::Client, Error, MAX_RETRIES}; +use time_tz::Tz; + use super::common::{decoders, encoders}; -use super::{BarSize, Duration, HistogramEntry, HistoricalData, Schedule, TickBidAsk, TickDecoder, TickLast, TickMidpoint, WhatToShow}; +use super::{ + BarSize, Duration, HistogramEntry, HistoricalBarUpdate, HistoricalData, Schedule, TickBidAsk, TickDecoder, TickLast, TickMidpoint, WhatToShow, +}; use crate::market_data::TradingHours; // Returns the timestamp of earliest available historical data for a contract and data type. @@ -249,6 +253,172 @@ pub(crate) fn histogram_data( } } +// === Historical Data Streaming with keepUpToDate === + +/// Requests historical data for a contract with optional streaming updates. +/// +/// When `keep_up_to_date` is `true`, this function requests historical bars and then +/// continues to receive streaming updates for the current (incomplete) bar. IBKR sends +/// updates approximately every 4-6 seconds until the bar completes, at which point a +/// new bar begins. +/// +/// When `keep_up_to_date` is `false`, only the initial historical data is returned +/// and the subscription ends after delivering the data. +pub(crate) fn historical_data_streaming( + client: &Client, + contract: &Contract, + duration: Duration, + bar_size: BarSize, + what_to_show: Option, + trading_hours: TradingHours, + keep_up_to_date: bool, +) -> Result { + if !contract.trading_class.is_empty() || contract.contract_id > 0 { + check_version(client.server_version(), Features::TRADING_CLASS)?; + } + + // Note: end_date must be None when keepUpToDate=true (IBKR requirement) + let builder = client.request(); + let request = encoders::encode_request_historical_data( + client.server_version(), + builder.request_id(), + contract, + None, // end_date must be None for keepUpToDate + duration, + bar_size, + what_to_show, + trading_hours.use_rth(), + keep_up_to_date, + Vec::::default(), + )?; + + let subscription = builder.send_raw(request)?; + + // Get the timezone directly + let tz: &'static Tz = client.time_zone.unwrap_or_else(|| { + warn!("server timezone unknown. assuming UTC, but that may be incorrect!"); + time_tz::timezones::db::UTC + }); + + Ok(HistoricalDataStreamingSubscription::new(subscription, client.server_version, tz)) +} + +/// Blocking subscription for streaming historical data with keepUpToDate=true. +/// +/// This subscription first yields the initial historical bars as a `Historical` variant, +/// then continues to yield streaming updates for the current bar as `Update` variants. +pub struct HistoricalDataStreamingSubscription { + messages: InternalSubscription, + server_version: i32, + time_zone: &'static Tz, + error: Mutex>, +} + +impl HistoricalDataStreamingSubscription { + fn new(messages: InternalSubscription, server_version: i32, time_zone: &'static Tz) -> Self { + Self { + messages, + server_version, + time_zone, + error: Mutex::new(None), + } + } + + /// Block until the next update is available. + /// + /// Returns: + /// - `Some(HistoricalBarUpdate::Historical(data))` - Initial batch of historical bars (always first) + /// - `Some(HistoricalBarUpdate::Update(bar))` - Streaming bar update + /// - `None` - Subscription ended (connection closed or error) + pub fn next(&self) -> Option { + self.next_helper(|| self.messages.next()) + } + + /// Attempt to fetch the next update without blocking. + pub fn try_next(&self) -> Option { + self.next_helper(|| self.messages.try_next()) + } + + /// Wait up to `duration` for the next update to arrive. + pub fn next_timeout(&self, duration: std::time::Duration) -> Option { + self.next_helper(|| self.messages.next_timeout(duration)) + } + + fn next_helper(&self, next_response: F) -> Option + where + F: Fn() -> Option, + { + self.clear_error(); + + loop { + match next_response() { + Some(Ok(mut message)) => { + match message.message_type() { + IncomingMessages::HistoricalData => { + // Initial historical data batch + match decoders::decode_historical_data(self.server_version, self.time_zone, &mut message) { + Ok(data) => { + return Some(HistoricalBarUpdate::Historical(data)); + } + Err(e) => { + self.set_error(e); + return None; + } + } + } + IncomingMessages::HistoricalDataUpdate => { + // Streaming bar update + match decoders::decode_historical_data_update(self.time_zone, &mut message) { + Ok(bar) => { + return Some(HistoricalBarUpdate::Update(bar)); + } + Err(e) => { + self.set_error(e); + return None; + } + } + } + IncomingMessages::Error => { + self.set_error(Error::from(message)); + return None; + } + _ => { + // Skip unexpected messages + debug!("unexpected message in streaming subscription: {:?}", message.message_type()); + continue; + } + } + } + Some(Err(e)) => { + self.set_error(e); + return None; + } + None => { + return None; + } + } + } + } + + /// Returns and clears the last error that occurred, if any. + pub fn error(&self) -> Option { + self.error.lock().unwrap().take() + } + + fn set_error(&self, e: Error) { + *self.error.lock().unwrap() = Some(e); + } + + fn clear_error(&self) { + *self.error.lock().unwrap() = None; + } + + /// Cancel the subscription. + pub fn cancel(&self) { + self.messages.cancel(); + } +} + // TickSubscription and related types /// Shared subscription handle that decodes historical tick batches as they arrive. @@ -1090,4 +1260,144 @@ mod tests { "time_zone should return the client's time zone when it is set" ); } + + #[test] + fn test_historical_data_streaming_with_updates() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + // Initial historical data (message type 17) + "17\09000\020230315 09:30:00\020230315 10:30:00\01\01678886400\0185.50\0186.00\0185.25\0185.75\01000\0185.70\0100\0".to_owned(), + // Streaming update (message type 90) + "90\09000\0-1\01678890000\0185.80\0186.10\0185.60\0185.90\0500\0185.85\050\0".to_owned(), + ], + }); + + let mut client = Client::stubbed(message_bus.clone(), server_versions::SIZE_RULES); + client.time_zone = Some(time_tz::timezones::db::UTC); + + let contract = Contract::stock("SPY").build(); + + let subscription = historical_data_streaming( + &client, + &contract, + Duration::days(1), + BarSize::Hour, + Some(WhatToShow::Trades), + TradingHours::Regular, + true, + ) + .expect("streaming request should succeed"); + + // First: receive initial historical data + let update1 = subscription.next(); + assert!(update1.is_some(), "Should receive initial historical data"); + match update1.unwrap() { + HistoricalBarUpdate::Historical(data) => { + assert_eq!(data.bars.len(), 1, "Should have 1 initial bar"); + assert_eq!(data.bars[0].open, 185.50, "Wrong open price"); + } + _ => panic!("Expected Historical variant"), + } + + // Second: receive streaming update + let update2 = subscription.next(); + assert!(update2.is_some(), "Should receive streaming update"); + match update2.unwrap() { + HistoricalBarUpdate::Update(bar) => { + assert_eq!(bar.open, 185.80, "Wrong open price in update"); + assert_eq!(bar.high, 186.10, "Wrong high price in update"); + assert_eq!(bar.close, 185.90, "Wrong close price in update"); + } + _ => panic!("Expected Update variant"), + } + + // Verify request message includes keepUpToDate=true + let request_messages = message_bus.request_messages.read().unwrap(); + assert_eq!(request_messages.len(), 1, "Should send one request"); + // keepUpToDate is at field index 21 (for non-bag contracts) + assert_eq!(request_messages[0].fields[21], "1", "Request should have keepUpToDate=true at field[21]"); + } + + #[test] + fn test_historical_data_streaming_keep_up_to_date_false() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + // Initial historical data only + "17\09000\020230315 09:30:00\020230315 10:30:00\01\01678886400\0185.50\0186.00\0185.25\0185.75\01000\0185.70\0100\0".to_owned(), + ], + }); + + let mut client = Client::stubbed(message_bus.clone(), server_versions::SIZE_RULES); + client.time_zone = Some(time_tz::timezones::db::UTC); + + let contract = Contract::stock("SPY").build(); + + let subscription = historical_data_streaming( + &client, + &contract, + Duration::days(1), + BarSize::Hour, + Some(WhatToShow::Trades), + TradingHours::Regular, + false, // keep_up_to_date = false + ) + .expect("streaming request should succeed"); + + // Receive initial historical data + let update1 = subscription.next(); + assert!(update1.is_some(), "Should receive initial historical data"); + match update1.unwrap() { + HistoricalBarUpdate::Historical(data) => { + assert_eq!(data.bars.len(), 1, "Should have 1 initial bar"); + } + _ => panic!("Expected Historical variant"), + } + + // Verify request message includes keepUpToDate=false + let request_messages = message_bus.request_messages.read().unwrap(); + assert_eq!(request_messages.len(), 1, "Should send one request"); + // keepUpToDate is at field index 21 (for non-bag contracts) + assert_eq!(request_messages[0].fields[21], "0", "Request should have keepUpToDate=false at field[21]"); + } + + #[test] + fn test_historical_data_streaming_error_response() { + let message_bus = Arc::new(MessageBusStub { + request_messages: RwLock::new(vec![]), + response_messages: vec![ + // Error response + "4\02\09000\0162\0Historical Market Data Service error message:No market data permissions.\0".to_owned(), + ], + }); + + let mut client = Client::stubbed(message_bus, server_versions::SIZE_RULES); + client.time_zone = Some(time_tz::timezones::db::UTC); + + let contract = Contract::stock("SPY").build(); + + let subscription = historical_data_streaming( + &client, + &contract, + Duration::days(1), + BarSize::Hour, + Some(WhatToShow::Trades), + TradingHours::Regular, + true, + ) + .expect("streaming request should succeed"); + + // Should return None due to error + let update = subscription.next(); + assert!(update.is_none(), "Should return None on error"); + + // Error should be accessible + let error = subscription.error(); + assert!(error.is_some(), "Error should be stored"); + assert!( + error.unwrap().to_string().contains("No market data permissions"), + "Error should contain the message" + ); + } } From 0d3524f71483583ef798ac020456787823f4865e Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 31 Jan 2026 00:15:31 -0800 Subject: [PATCH 16/18] Refactor: consolidate decoder context, add Client::decoder_context() (#386) * Refactor: consolidate decoder context, add Client::decoder_context() - Rename ResponseContext to DecoderContext - Add server_version and time_zone fields to DecoderContext - Change StreamDecoder::decode() signature to take &DecoderContext - Add decoder_context() helper to sync and async Client - Fix async SubscriptionBuilder to preserve time_zone - Replace 20+ DecoderContext::new() calls with client.decoder_context() * Format code --- src/accounts/common/stream_decoders.rs | 90 +++++----- src/client/async.rs | 10 ++ src/client/builders/async.rs | 74 ++++----- src/client/builders/sync.rs | 36 ++-- src/client/mod.rs | 2 +- src/client/sync.rs | 10 ++ src/contracts/async.rs | 21 +-- src/contracts/common/stream_decoders.rs | 10 +- src/contracts/sync.rs | 21 +-- src/display_groups/common/stream_decoders.rs | 16 +- src/market_data/realtime/async.rs | 33 ++-- src/market_data/realtime/mod.rs | 34 ++-- src/market_data/realtime/sync.rs | 9 +- src/news/async.rs | 31 ++-- src/news/sync.rs | 38 +---- src/orders/async.rs | 98 +++++++---- src/orders/sync.rs | 141 +++++++--------- src/scanner/async.rs | 9 +- src/scanner/sync.rs | 13 +- src/subscriptions/async.rs | 164 +++++++------------ src/subscriptions/common.rs | 135 +++++++++++---- src/subscriptions/mod.rs | 2 +- src/subscriptions/sync.rs | 25 +-- src/wsh/async.rs | 8 +- src/wsh/common/stream_decoders.rs | 10 +- src/wsh/sync.rs | 6 +- 26 files changed, 518 insertions(+), 528 deletions(-) diff --git a/src/accounts/common/stream_decoders.rs b/src/accounts/common/stream_decoders.rs index 5a99e8ce..7438ca58 100644 --- a/src/accounts/common/stream_decoders.rs +++ b/src/accounts/common/stream_decoders.rs @@ -5,7 +5,7 @@ use crate::accounts::*; use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; -use crate::subscriptions::{ResponseContext, StreamDecoder}; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::Error; use super::{decoders, encoders}; @@ -14,15 +14,18 @@ use crate::common::error_helpers; impl StreamDecoder for AccountSummaryResult { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::AccountSummary, IncomingMessages::AccountSummaryEnd]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::AccountSummary => Ok(AccountSummaryResult::Summary(decoders::decode_account_summary(server_version, message)?)), + IncomingMessages::AccountSummary => Ok(AccountSummaryResult::Summary(decoders::decode_account_summary( + context.server_version, + message, + )?)), IncomingMessages::AccountSummaryEnd => Ok(AccountSummaryResult::End), message => Err(Error::Simple(format!("unexpected message: {message:?}"))), } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id(request_id)?; encoders::encode_cancel_account_summary(request_id) } @@ -31,11 +34,11 @@ impl StreamDecoder for AccountSummaryResult { impl StreamDecoder for PnL { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::PnL]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { - decoders::decode_pnl(server_version, message) + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { + decoders::decode_pnl(context.server_version, message) } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id(request_id)?; encoders::encode_cancel_pnl(request_id) } @@ -44,11 +47,11 @@ impl StreamDecoder for PnL { impl StreamDecoder for PnLSingle { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::PnLSingle]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { - decoders::decode_pnl_single(server_version, message) + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { + decoders::decode_pnl_single(context.server_version, message) } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id(request_id)?; encoders::encode_cancel_pnl_single(request_id) } @@ -57,7 +60,7 @@ impl StreamDecoder for PnLSingle { impl StreamDecoder for PositionUpdate { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::Position, IncomingMessages::PositionEnd]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::Position => Ok(PositionUpdate::Position(decoders::decode_position(message)?)), IncomingMessages::PositionEnd => Ok(PositionUpdate::PositionEnd), @@ -65,7 +68,7 @@ impl StreamDecoder for PositionUpdate { } } - fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&DecoderContext>) -> Result { encoders::encode_cancel_positions() } } @@ -73,7 +76,7 @@ impl StreamDecoder for PositionUpdate { impl StreamDecoder for PositionUpdateMulti { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::PositionMulti, IncomingMessages::PositionMultiEnd]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::PositionMulti => Ok(PositionUpdateMulti::Position(decoders::decode_position_multi(message)?)), IncomingMessages::PositionMultiEnd => Ok(PositionUpdateMulti::PositionEnd), @@ -81,7 +84,7 @@ impl StreamDecoder for PositionUpdateMulti { } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id(request_id)?; encoders::encode_cancel_positions_multi(request_id) } @@ -95,11 +98,11 @@ impl StreamDecoder for AccountUpdate { IncomingMessages::AccountDownloadEnd, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::AccountValue => Ok(AccountUpdate::AccountValue(decoders::decode_account_value(message)?)), IncomingMessages::PortfolioValue => Ok(AccountUpdate::PortfolioValue(decoders::decode_account_portfolio_value( - server_version, + context.server_version, message, )?)), IncomingMessages::AccountUpdateTime => Ok(AccountUpdate::UpdateTime(decoders::decode_account_update_time(message)?)), @@ -108,7 +111,7 @@ impl StreamDecoder for AccountUpdate { } } - fn cancel_message(server_version: i32, _request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(server_version: i32, _request_id: Option, _context: Option<&DecoderContext>) -> Result { encoders::encode_cancel_account_updates(server_version) } } @@ -116,7 +119,7 @@ impl StreamDecoder for AccountUpdate { impl StreamDecoder for AccountUpdateMulti { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::AccountUpdateMulti, IncomingMessages::AccountUpdateMultiEnd]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::AccountUpdateMulti => Ok(AccountUpdateMulti::AccountMultiValue(decoders::decode_account_multi_value(message)?)), IncomingMessages::AccountUpdateMultiEnd => Ok(AccountUpdateMulti::End), @@ -124,7 +127,7 @@ impl StreamDecoder for AccountUpdateMulti { } } - fn cancel_message(server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id_for(request_id, "encode cancel account updates multi")?; encoders::encode_cancel_account_updates_multi(server_version, request_id) } @@ -135,12 +138,14 @@ mod tests { use super::*; use crate::common::test_utils::helpers::*; use crate::messages::OutgoingMessages; - use crate::subscriptions::ResponseContext; - // Test data const TEST_REQUEST_ID: i32 = 123; const TEST_SERVER_VERSION: i32 = 151; + fn test_context() -> DecoderContext { + DecoderContext::new(TEST_SERVER_VERSION, None) + } + mod account_summary_tests { use super::*; @@ -149,7 +154,7 @@ mod tests { // Format: message_type\0version\0request_id\0account\0tag\0value\0currency\0 let mut message = ResponseMessage::from("63\01\0123\0DU1234567\0NetLiquidation\0123456.78\0USD\0"); - let result = AccountSummaryResult::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountSummaryResult::decode(&test_context(), &mut message).unwrap(); match result { AccountSummaryResult::Summary(summary) => { @@ -167,7 +172,7 @@ mod tests { // Format: message_type\0version\0request_id\0 let mut message = ResponseMessage::from("64\01\0123\0"); - let result = AccountSummaryResult::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountSummaryResult::decode(&test_context(), &mut message).unwrap(); assert!(matches!(result, AccountSummaryResult::End)); } @@ -177,7 +182,7 @@ mod tests { // Using Error message type which is not expected for AccountSummaryResult let mut message = ResponseMessage::from("4\02\0123\0Some error\0"); - let result = AccountSummaryResult::decode(TEST_SERVER_VERSION, &mut message); + let result = AccountSummaryResult::decode(&test_context(), &mut message); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("unexpected message")); @@ -217,7 +222,7 @@ mod tests { // Format: message_type\0request_id\0daily_pnl\0unrealized_pnl\0realized_pnl\0 let mut message = ResponseMessage::from("94\0123\01234.56\02345.67\03456.78\0"); - let result = PnL::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PnL::decode(&test_context(), &mut message).unwrap(); assert_eq!(result.daily_pnl, 1234.56); assert_eq!(result.unrealized_pnl, Some(2345.67)); @@ -253,7 +258,7 @@ mod tests { // Format: message_type\0request_id\0position\0daily_pnl\0unrealized_pnl\0realized_pnl\0value\0 let mut message = ResponseMessage::from("95\0123\0100\01234.56\02345.67\03456.78\04567.89\0"); - let result = PnLSingle::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PnLSingle::decode(&test_context(), &mut message).unwrap(); assert_eq!(result.position, 100.0); assert_eq!(result.daily_pnl, 1234.56); @@ -284,7 +289,7 @@ mod tests { // Format: message_type\0version\0account\0contract_id\0symbol\0sec_type\0last_trade_date\0strike\0right\0multiplier\0exchange\0currency\0local_symbol\0trading_class\0position\0avg_cost\0 let mut message = ResponseMessage::from("61\03\0DU1234567\012345\0AAPL\0STK\0\00.0\0\0\0NASDAQ\0USD\0AAPL\0NMS\0100\050.25\0"); - let result = PositionUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PositionUpdate::decode(&test_context(), &mut message).unwrap(); match result { PositionUpdate::Position(pos) => { @@ -302,7 +307,7 @@ mod tests { // Format: message_type\0version\0 let mut message = ResponseMessage::from("62\01\0"); - let result = PositionUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PositionUpdate::decode(&test_context(), &mut message).unwrap(); assert!(matches!(result, PositionUpdate::PositionEnd)); } @@ -333,7 +338,7 @@ mod tests { let mut message = ResponseMessage::from("71\01\0123\0DU1234567\012345\0AAPL\0STK\0\00.0\0\0\0NASDAQ\0USD\0AAPL\0NMS\0100\050.25\0TARGET2024\0"); - let result = PositionUpdateMulti::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PositionUpdateMulti::decode(&test_context(), &mut message).unwrap(); match result { PositionUpdateMulti::Position(pos) => { @@ -352,7 +357,7 @@ mod tests { // Format: message_type\0version\0request_id\0 let mut message = ResponseMessage::from("72\01\0123\0"); - let result = PositionUpdateMulti::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PositionUpdateMulti::decode(&test_context(), &mut message).unwrap(); assert!(matches!(result, PositionUpdateMulti::PositionEnd)); } @@ -390,7 +395,7 @@ mod tests { // Format: message_type\0version\0key\0value\0currency\0account\0 let mut message = ResponseMessage::from("6\02\0NetLiquidation\0123456.78\0USD\0DU1234567\0"); - let result = AccountUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountUpdate::decode(&test_context(), &mut message).unwrap(); match result { AccountUpdate::AccountValue(val) => { @@ -410,7 +415,7 @@ mod tests { "7\08\012345\0AAPL\0STK\020230101\0150.0\0\0\0NASDAQ\0USD\0AAPL\0NMS\0100\0155.0\015500.0\0150.0\0500.0\00.0\0DU1234567\0", ); - let result = AccountUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountUpdate::decode(&test_context(), &mut message).unwrap(); match result { AccountUpdate::PortfolioValue(val) => { @@ -429,7 +434,7 @@ mod tests { // Format: message_type\0version\0timestamp\0 let mut message = ResponseMessage::from("8\01\014:30:00\0"); - let result = AccountUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountUpdate::decode(&test_context(), &mut message).unwrap(); match result { AccountUpdate::UpdateTime(time) => { @@ -444,7 +449,7 @@ mod tests { // Format: message_type\0version\0account\0 let mut message = ResponseMessage::from("54\01\0DU1234567\0"); - let result = AccountUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountUpdate::decode(&test_context(), &mut message).unwrap(); assert!(matches!(result, AccountUpdate::End)); } @@ -478,7 +483,7 @@ mod tests { // Format: message_type\0version\0request_id\0account\0model_code\0key\0value\0currency\0 let mut message = ResponseMessage::from("73\01\0123\0DU1234567\0TARGET2024\0NetLiquidation\0123456.78\0USD\0"); - let result = AccountUpdateMulti::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountUpdateMulti::decode(&test_context(), &mut message).unwrap(); match result { AccountUpdateMulti::AccountMultiValue(val) => { @@ -497,7 +502,7 @@ mod tests { // Format: message_type\0version\0request_id\0 let mut message = ResponseMessage::from("74\01\0123\0"); - let result = AccountUpdateMulti::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountUpdateMulti::decode(&test_context(), &mut message).unwrap(); assert!(matches!(result, AccountUpdateMulti::End)); } @@ -537,7 +542,7 @@ mod tests { // Invalid message with missing fields let mut message = ResponseMessage::from("63\01\0123\0"); - let result = AccountSummaryResult::decode(TEST_SERVER_VERSION, &mut message); + let result = AccountSummaryResult::decode(&test_context(), &mut message); assert!(result.is_err()); } @@ -547,7 +552,7 @@ mod tests { // Create a message with missing fields (only has account, missing tag, value, currency) let mut message = ResponseMessage::from("63\01\0123\0DU1234567\0"); - let result = AccountSummaryResult::decode(TEST_SERVER_VERSION, &mut message); + let result = AccountSummaryResult::decode(&test_context(), &mut message); assert!(result.is_err()); } @@ -555,10 +560,7 @@ mod tests { #[test] fn test_context_parameter_ignored() { // All cancel_message implementations should ignore the context parameter - let context = ResponseContext { - request_type: Some(OutgoingMessages::RequestMarketData), - is_smart_depth: false, - }; + let context = DecoderContext::new(TEST_SERVER_VERSION, None).with_request_type(OutgoingMessages::RequestMarketData); // Test that context is ignored (should produce same result with or without) let result1 = AccountSummaryResult::cancel_message(TEST_SERVER_VERSION, Some(TEST_REQUEST_ID), None).unwrap(); @@ -585,7 +587,7 @@ mod tests { let mut results = Vec::new(); for mut message in messages { - let result = AccountSummaryResult::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = AccountSummaryResult::decode(&test_context(), &mut message).unwrap(); results.push(result); } @@ -606,7 +608,7 @@ mod tests { let mut results = Vec::new(); for mut message in messages { - let result = PositionUpdate::decode(TEST_SERVER_VERSION, &mut message).unwrap(); + let result = PositionUpdate::decode(&test_context(), &mut message).unwrap(); results.push(result); } diff --git a/src/client/async.rs b/src/client/async.rs index 0189dcbb..7f9c27d7 100644 --- a/src/client/async.rs +++ b/src/client/async.rs @@ -154,6 +154,16 @@ impl Client { self.connection_time } + /// Returns the server's time zone + pub fn time_zone(&self) -> Option<&'static Tz> { + self.time_zone + } + + /// Returns a decoder context for this client + pub(crate) fn decoder_context(&self) -> crate::subscriptions::DecoderContext { + crate::subscriptions::DecoderContext::new(self.server_version, self.time_zone) + } + /// Returns true if the client is currently connected to TWS/IB Gateway. /// /// This method checks if the underlying connection to TWS or IB Gateway is active. diff --git a/src/client/builders/async.rs b/src/client/builders/async.rs index a7841e39..246000ef 100644 --- a/src/client/builders/async.rs +++ b/src/client/builders/async.rs @@ -8,7 +8,7 @@ use async_trait::async_trait; use crate::client::r#async::Client; use crate::errors::Error; use crate::messages::{OutgoingMessages, RequestMessage}; -use crate::subscriptions::{ResponseContext, StreamDecoder, Subscription}; +use crate::subscriptions::{DecoderContext, StreamDecoder, Subscription}; use crate::transport::{AsyncInternalSubscription, AsyncMessageBus}; /// Builder for creating requests with IDs @@ -49,22 +49,20 @@ impl<'a> RequestBuilder<'a> { where T: StreamDecoder + Send + 'static, { - let server_version = self.client.server_version(); + let context = self.client.decoder_context(); let message_bus = self.client.message_bus.clone(); - SubscriptionBuilder::::new_with_components(server_version, message_bus) + SubscriptionBuilder::::new_with_components(context, message_bus) .send_with_request_id::(self.request_id, message) .await } /// Send the request and create a subscription with context - pub async fn send_with_context(self, message: RequestMessage, context: ResponseContext) -> Result, Error> + pub async fn send_with_context(self, message: RequestMessage, context: DecoderContext) -> Result, Error> where T: StreamDecoder + Send + 'static, { - let server_version = self.client.server_version(); let message_bus = self.client.message_bus.clone(); - SubscriptionBuilder::::new_with_components(server_version, message_bus) - .with_context(context) + SubscriptionBuilder::::new_with_components(context, message_bus) .send_with_request_id::(self.request_id, message) .await } @@ -100,22 +98,20 @@ impl<'a> SharedRequestBuilder<'a> { where T: StreamDecoder + Send + 'static, { - let server_version = self.client.server_version(); + let context = self.client.decoder_context(); let message_bus = self.client.message_bus.clone(); - SubscriptionBuilder::::new_with_components(server_version, message_bus) + SubscriptionBuilder::::new_with_components(context, message_bus) .send_shared::(self.message_type, message) .await } /// Send the request and create a subscription with context - pub async fn send_with_context(self, message: RequestMessage, context: ResponseContext) -> Result, Error> + pub async fn send_with_context(self, message: RequestMessage, context: DecoderContext) -> Result, Error> where T: StreamDecoder + Send + 'static, { - let server_version = self.client.server_version(); let message_bus = self.client.message_bus.clone(); - SubscriptionBuilder::::new_with_components(server_version, message_bus) - .with_context(context) + SubscriptionBuilder::::new_with_components(context, message_bus) .send_shared::(self.message_type, message) .await } @@ -193,9 +189,8 @@ impl<'a> MessageBuilder<'a> { /// Builder for creating subscriptions with consistent patterns #[allow(dead_code)] pub(crate) struct SubscriptionBuilder { - server_version: i32, message_bus: Arc, - context: ResponseContext, + context: DecoderContext, _phantom: PhantomData, } @@ -205,17 +200,16 @@ where T: Send + 'static, { /// Creates a new subscription builder from components - pub fn new_with_components(server_version: i32, message_bus: Arc) -> Self { + pub fn new_with_components(context: DecoderContext, message_bus: Arc) -> Self { Self { - server_version, message_bus, - context: ResponseContext::default(), + context, _phantom: PhantomData, } } /// Sets the response context - pub fn with_context(mut self, context: ResponseContext) -> Self { + pub fn with_context(mut self, context: DecoderContext) -> Self { self.context = context; self } @@ -231,13 +225,10 @@ where where D: StreamDecoder + 'static, { - // Use atomic subscribe + send let subscription = self.message_bus.send_request(request_id, message).await?; - // Create subscription with decoder Ok(Subscription::new_from_internal::( subscription, - self.server_version, self.message_bus.clone(), Some(request_id), None, @@ -251,12 +242,10 @@ where where D: StreamDecoder + 'static, { - // Use atomic subscribe + send let subscription = self.message_bus.send_shared_request(message_type, message).await?; Ok(Subscription::new_from_internal::( subscription, - self.server_version, self.message_bus.clone(), None, None, @@ -270,12 +259,10 @@ where where D: StreamDecoder + 'static, { - // Use atomic subscribe + send let subscription = self.message_bus.send_order_request(order_id, message).await?; Ok(Subscription::new_from_internal::( subscription, - self.server_version, self.message_bus.clone(), None, Some(order_id), @@ -348,9 +335,9 @@ impl SubscriptionBuilderExt for Client { where T: Send + 'static, { - let server_version = self.server_version(); + let context = self.decoder_context(); let message_bus = self.message_bus.clone(); - SubscriptionBuilder::new_with_components(server_version, message_bus) + SubscriptionBuilder::new_with_components(context, message_bus) } } @@ -449,9 +436,9 @@ mod tests { #[tokio::test] async fn test_subscription_builder_new() { let (client, _gateway) = create_test_client().await; - let server_version = client.server_version(); + let context = client.decoder_context(); let message_bus = client.message_bus.clone(); - let builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(server_version, message_bus); + let builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(context, message_bus); // Builder created successfully let _ = builder; } @@ -459,22 +446,22 @@ mod tests { #[tokio::test] async fn test_subscription_builder_with_context() { let (client, _gateway) = create_test_client().await; - let server_version = client.server_version(); + let context = client + .decoder_context() + .with_smart_depth(true) + .with_request_type(OutgoingMessages::RequestMarketData); let message_bus = client.message_bus.clone(); - let context = ResponseContext { - is_smart_depth: true, - request_type: Some(OutgoingMessages::RequestMarketData), - }; - let builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(server_version, message_bus).with_context(context.clone()); + let builder: SubscriptionBuilder = + SubscriptionBuilder::new_with_components(client.decoder_context(), message_bus).with_context(context.clone()); assert_eq!(builder.context, context); } #[tokio::test] async fn test_subscription_builder_with_smart_depth() { let (client, _gateway) = create_test_client().await; - let server_version = client.server_version(); + let context = client.decoder_context(); let message_bus = client.message_bus.clone(); - let builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(server_version, message_bus).with_smart_depth(true); + let builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(context, message_bus).with_smart_depth(true); assert!(builder.context.is_smart_depth); } @@ -627,9 +614,9 @@ mod tests { for tc in test_cases { let (client, _gateway) = create_test_client().await; - let server_version = client.server_version(); + let context = client.decoder_context(); let message_bus = client.message_bus.clone(); - let mut builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(server_version, message_bus); + let mut builder: SubscriptionBuilder = SubscriptionBuilder::new_with_components(context, message_bus); // Set initial context builder.context.is_smart_depth = tc.initial_smart_depth; @@ -641,10 +628,9 @@ mod tests { } if let Some(request_type) = tc.set_request_type { - let context = ResponseContext { - is_smart_depth: builder.context.is_smart_depth, - request_type: Some(request_type), - }; + let context = DecoderContext::new(builder.context.server_version, builder.context.time_zone) + .with_smart_depth(builder.context.is_smart_depth) + .with_request_type(request_type); builder = builder.with_context(context); } diff --git a/src/client/builders/sync.rs b/src/client/builders/sync.rs index aa337457..df9cca0f 100644 --- a/src/client/builders/sync.rs +++ b/src/client/builders/sync.rs @@ -8,7 +8,7 @@ use crate::client::StreamDecoder; use crate::errors::Error; use crate::messages::{OutgoingMessages, RequestMessage}; use crate::subscriptions::sync::Subscription; -use crate::subscriptions::ResponseContext; +use crate::subscriptions::DecoderContext; use crate::transport::InternalSubscription; /// Builder for creating requests with IDs @@ -53,7 +53,7 @@ impl<'a> RequestBuilder<'a> { } /// Send the request and create a subscription with context - pub fn send_with_context(self, message: RequestMessage, context: ResponseContext) -> Result, Error> + pub fn send_with_context(self, message: RequestMessage, context: DecoderContext) -> Result, Error> where T: StreamDecoder, { @@ -97,7 +97,7 @@ impl<'a> SharedRequestBuilder<'a> { } /// Send the request and create a subscription with context - pub fn send_with_context(self, message: RequestMessage, context: ResponseContext) -> Result, Error> + pub fn send_with_context(self, message: RequestMessage, context: DecoderContext) -> Result, Error> where T: StreamDecoder, { @@ -180,7 +180,7 @@ impl<'a> MessageBuilder<'a> { #[allow(dead_code)] pub(crate) struct SubscriptionBuilder<'a, T> { client: &'a Client, - context: ResponseContext, + context: DecoderContext, _phantom: PhantomData, } @@ -193,13 +193,13 @@ where pub fn new(client: &'a Client) -> Self { Self { client, - context: ResponseContext::default(), + context: client.decoder_context(), _phantom: PhantomData, } } /// Sets the response context for special handling - pub fn with_context(mut self, context: ResponseContext) -> Self { + pub fn with_context(mut self, context: DecoderContext) -> Self { self.context = context; self } @@ -212,12 +212,7 @@ where /// Builds a subscription from an internal subscription (already sent) pub fn build(self, subscription: InternalSubscription) -> Subscription { - Subscription::new( - self.client.server_version, - Arc::clone(&self.client.message_bus), - subscription, - Some(self.context), - ) + Subscription::new(Arc::clone(&self.client.message_bus), subscription, self.context) } /// Sends a request with a specific request ID and builds the subscription @@ -312,7 +307,7 @@ mod tests { use crate::client::common::tests::setup_connect; use crate::market_data::realtime::Bar; use crate::messages::OutgoingMessages; - use crate::subscriptions::ResponseContext; + use crate::subscriptions::DecoderContext; fn create_test_client() -> (Client, MockGateway) { let gateway = setup_connect(); @@ -409,10 +404,10 @@ mod tests { #[test] fn test_subscription_builder_with_context() { let (client, _gateway) = create_test_client(); - let context = ResponseContext { - is_smart_depth: true, - request_type: Some(OutgoingMessages::RequestMarketData), - }; + let context = client + .decoder_context() + .with_smart_depth(true) + .with_request_type(OutgoingMessages::RequestMarketData); let builder: SubscriptionBuilder = SubscriptionBuilder::new(&client).with_context(context.clone()); assert_eq!(builder.context, context); } @@ -585,10 +580,9 @@ mod tests { } if let Some(request_type) = tc.set_request_type { - let context = ResponseContext { - is_smart_depth: builder.context.is_smart_depth, - request_type: Some(request_type), - }; + let context = DecoderContext::new(builder.context.server_version, builder.context.time_zone) + .with_smart_depth(builder.context.is_smart_depth) + .with_request_type(request_type); builder = builder.with_context(context); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 80910a07..53b90615 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -28,7 +28,7 @@ pub use r#async::Client; pub use sync::Client; #[cfg(feature = "sync")] -pub(crate) use crate::subscriptions::{ResponseContext, StreamDecoder}; +pub(crate) use crate::subscriptions::StreamDecoder; #[cfg(all(feature = "sync", not(feature = "async")))] pub use crate::subscriptions::sync::Subscription; diff --git a/src/client/sync.rs b/src/client/sync.rs index bc97b7db..00071794 100644 --- a/src/client/sync.rs +++ b/src/client/sync.rs @@ -230,6 +230,16 @@ impl Client { self.connection_time } + /// Returns the server's time zone + pub fn time_zone(&self) -> Option<&'static Tz> { + self.time_zone + } + + /// Returns a decoder context for this client + pub(crate) fn decoder_context(&self) -> crate::subscriptions::DecoderContext { + crate::subscriptions::DecoderContext::new(self.server_version, self.time_zone) + } + /// Returns true if the client is currently connected to TWS/IB Gateway. /// /// This method checks if the underlying connection to TWS or IB Gateway is active. diff --git a/src/contracts/async.rs b/src/contracts/async.rs index 4a955ac7..aff0454d 100644 --- a/src/contracts/async.rs +++ b/src/contracts/async.rs @@ -143,7 +143,7 @@ pub async fn calculate_option_price( let mut subscription = builder.send_raw(message).await?; match subscription.next().await { - Some(Ok(mut message)) => OptionComputation::decode(client.server_version(), &mut message), + Some(Ok(mut message)) => OptionComputation::decode(&client.decoder_context(), &mut message), Some(Err(e)) => Err(e), None => Err(Error::Simple("no data for option calculation".into())), } @@ -169,7 +169,7 @@ pub async fn calculate_implied_volatility( let mut subscription = builder.send_raw(message).await?; match subscription.next().await { - Some(Ok(mut message)) => OptionComputation::decode(client.server_version(), &mut message), + Some(Ok(mut message)) => OptionComputation::decode(&client.decoder_context(), &mut message), Some(Err(e)) => Err(e), None => Err(Error::Simple("no data for option calculation".into())), } @@ -196,7 +196,7 @@ mod tests { use crate::messages::ResponseMessage; use crate::server_versions; use crate::stubs::MessageBusStub; - use crate::subscriptions::{ResponseContext, StreamDecoder}; + use crate::subscriptions::{DecoderContext, StreamDecoder}; use std::sync::{Arc, RwLock}; #[tokio::test] @@ -398,14 +398,14 @@ mod tests { match &test_case.expected_result { StreamDecoderResult::OptionComputation { price, delta } => { - let result = OptionComputation::decode(server_versions::SIZE_RULES, &mut message); + let result = OptionComputation::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!(result.is_ok(), "Test '{}' failed: {:?}", test_case.name, result.err()); let computation = result.unwrap(); assert_eq!(computation.option_price, Some(*price), "Test '{}' price mismatch", test_case.name); assert_eq!(computation.delta, Some(*delta), "Test '{}' delta mismatch", test_case.name); } StreamDecoderResult::OptionChain { exchange, underlying_conid } => { - let result = OptionChain::decode(server_versions::SIZE_RULES, &mut message); + let result = OptionChain::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!(result.is_ok(), "Test '{}' failed: {:?}", test_case.name, result.err()); let chain = result.unwrap(); assert_eq!(chain.exchange, *exchange, "Test '{}' exchange mismatch", test_case.name); @@ -419,7 +419,7 @@ mod tests { match test_case.message { msg if msg.starts_with("76") => { // OptionChain end of stream - let result = OptionChain::decode(server_versions::SIZE_RULES, &mut message); + let result = OptionChain::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!(result.is_err(), "Test '{}' should have failed", test_case.name); assert!( format!("{:?}", result.err()).contains(expected_error), @@ -429,8 +429,8 @@ mod tests { } _ => { // Try both decoders - let opt_result = OptionComputation::decode(server_versions::SIZE_RULES, &mut message.clone()); - let chain_result = OptionChain::decode(server_versions::SIZE_RULES, &mut message); + let opt_result = OptionComputation::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message.clone()); + let chain_result = OptionChain::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!( opt_result.is_err() && chain_result.is_err(), "Test '{}' should have failed", @@ -446,10 +446,7 @@ mod tests { #[tokio::test] async fn test_cancel_messages() { for test_case in cancel_message_test_cases() { - let context = test_case.request_type.map(|rt| ResponseContext { - request_type: Some(rt), - is_smart_depth: false, - }); + let context = test_case.request_type.map(|rt| DecoderContext::default().with_request_type(rt)); let result = match test_case.decoder_type { "OptionComputation" => OptionComputation::cancel_message(server_versions::SIZE_RULES, test_case.request_id, context.as_ref()), diff --git a/src/contracts/common/stream_decoders.rs b/src/contracts/common/stream_decoders.rs index 4a2d4b91..b6e203f4 100644 --- a/src/contracts/common/stream_decoders.rs +++ b/src/contracts/common/stream_decoders.rs @@ -5,7 +5,7 @@ use crate::contracts::*; use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; -use crate::subscriptions::{ResponseContext, StreamDecoder}; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::Error; use super::decoders; @@ -14,14 +14,14 @@ use super::encoders; impl StreamDecoder for OptionComputation { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::TickOptionComputation]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::TickOptionComputation => Ok(decoders::decode_option_computation(server_version, message)?), + IncomingMessages::TickOptionComputation => Ok(decoders::decode_option_computation(context.server_version, message)?), message => Err(Error::Simple(format!("unexpected message: {message:?}"))), } } - fn cancel_message(_server_version: i32, request_id: Option, context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("request id required to cancel option calculations"); match context.and_then(|c| c.request_type) { Some(OutgoingMessages::ReqCalcImpliedVolat) => { @@ -37,7 +37,7 @@ impl StreamDecoder for OptionComputation { } impl StreamDecoder for OptionChain { - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::SecurityDefinitionOptionParameter => Ok(decoders::decode_option_chain(message)?), IncomingMessages::SecurityDefinitionOptionParameterEnd => Err(Error::EndOfStream), diff --git a/src/contracts/sync.rs b/src/contracts/sync.rs index b79f14b9..25e15f4e 100644 --- a/src/contracts/sync.rs +++ b/src/contracts/sync.rs @@ -133,7 +133,7 @@ pub(crate) fn calculate_option_price( let subscription = builder.send_raw(message)?; match subscription.next() { - Some(Ok(mut message)) => OptionComputation::decode(client.server_version, &mut message), + Some(Ok(mut message)) => OptionComputation::decode(&client.decoder_context(), &mut message), Some(Err(e)) => Err(e), None => Err(Error::Simple("no data for option calculation".into())), } @@ -159,7 +159,7 @@ pub(crate) fn calculate_implied_volatility( let subscription = builder.send_raw(message)?; match subscription.next() { - Some(Ok(mut message)) => OptionComputation::decode(client.server_version, &mut message), + Some(Ok(mut message)) => OptionComputation::decode(&client.decoder_context(), &mut message), Some(Err(e)) => Err(e), None => Err(Error::Simple("no data for option calculation".into())), } @@ -184,7 +184,7 @@ mod tests { use crate::messages::ResponseMessage; use crate::server_versions; use crate::stubs::MessageBusStub; - use crate::subscriptions::{ResponseContext, StreamDecoder}; + use crate::subscriptions::{DecoderContext, StreamDecoder}; use std::sync::{Arc, RwLock}; #[test] @@ -382,14 +382,14 @@ mod tests { match &test_case.expected_result { StreamDecoderResult::OptionComputation { price, delta } => { - let result = OptionComputation::decode(server_versions::SIZE_RULES, &mut message); + let result = OptionComputation::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!(result.is_ok(), "Test '{}' failed: {:?}", test_case.name, result.err()); let computation = result.unwrap(); assert_eq!(computation.option_price, Some(*price), "Test '{}' price mismatch", test_case.name); assert_eq!(computation.delta, Some(*delta), "Test '{}' delta mismatch", test_case.name); } StreamDecoderResult::OptionChain { exchange, underlying_conid } => { - let result = OptionChain::decode(server_versions::SIZE_RULES, &mut message); + let result = OptionChain::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!(result.is_ok(), "Test '{}' failed: {:?}", test_case.name, result.err()); let chain = result.unwrap(); assert_eq!(chain.exchange, *exchange, "Test '{}' exchange mismatch", test_case.name); @@ -403,7 +403,7 @@ mod tests { match test_case.message { msg if msg.starts_with("76") => { // OptionChain end of stream - let result = OptionChain::decode(server_versions::SIZE_RULES, &mut message); + let result = OptionChain::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!(result.is_err(), "Test '{}' should have failed", test_case.name); assert!( format!("{:?}", result.err()).contains(expected_error), @@ -413,8 +413,8 @@ mod tests { } _ => { // Try both decoders - let opt_result = OptionComputation::decode(server_versions::SIZE_RULES, &mut message.clone()); - let chain_result = OptionChain::decode(server_versions::SIZE_RULES, &mut message); + let opt_result = OptionComputation::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message.clone()); + let chain_result = OptionChain::decode(&DecoderContext::new(server_versions::SIZE_RULES, None), &mut message); assert!( opt_result.is_err() && chain_result.is_err(), "Test '{}' should have failed", @@ -430,10 +430,7 @@ mod tests { #[test] fn test_cancel_messages() { for test_case in cancel_message_test_cases() { - let context = test_case.request_type.map(|rt| ResponseContext { - request_type: Some(rt), - is_smart_depth: false, - }); + let context = test_case.request_type.map(|rt| DecoderContext::default().with_request_type(rt)); let result = match test_case.decoder_type { "OptionComputation" => OptionComputation::cancel_message(server_versions::SIZE_RULES, test_case.request_id, context.as_ref()), diff --git a/src/display_groups/common/stream_decoders.rs b/src/display_groups/common/stream_decoders.rs index 82580f51..f338a3fe 100644 --- a/src/display_groups/common/stream_decoders.rs +++ b/src/display_groups/common/stream_decoders.rs @@ -2,7 +2,7 @@ use crate::common::error_helpers; use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; -use crate::subscriptions::{ResponseContext, StreamDecoder}; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::Error; use super::{decoders, encoders}; @@ -27,14 +27,14 @@ impl DisplayGroupUpdate { impl StreamDecoder for DisplayGroupUpdate { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::DisplayGroupUpdated]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::DisplayGroupUpdated => decoders::decode_display_group_updated(message), _ => Err(Error::UnexpectedResponse(message.clone())), } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id_for(request_id, "unsubscribe from group events")?; encoders::encode_unsubscribe_from_group_events(request_id) } @@ -49,11 +49,15 @@ mod tests { ResponseMessage::from(&raw) } + fn test_context() -> DecoderContext { + DecoderContext::new(176, None) + } + #[test] fn test_decode_display_group_update() { let mut message = make_response(&["68", "1", "9000", "265598@SMART"]); - let result = DisplayGroupUpdate::decode(176, &mut message).expect("decoding failed"); + let result = DisplayGroupUpdate::decode(&test_context(), &mut message).expect("decoding failed"); assert_eq!(result.contract_info, "265598@SMART"); } @@ -62,7 +66,7 @@ mod tests { fn test_decode_display_group_update_empty() { let mut message = make_response(&["68", "1", "9000"]); - let result = DisplayGroupUpdate::decode(176, &mut message).expect("decoding failed"); + let result = DisplayGroupUpdate::decode(&test_context(), &mut message).expect("decoding failed"); assert_eq!(result.contract_info, ""); } @@ -71,7 +75,7 @@ mod tests { fn test_decode_wrong_message_type() { let mut message = make_response(&["67", "1", "9000", "data"]); - let result = DisplayGroupUpdate::decode(176, &mut message); + let result = DisplayGroupUpdate::decode(&test_context(), &mut message); assert!(result.is_err()); } diff --git a/src/market_data/realtime/async.rs b/src/market_data/realtime/async.rs index 3bdb6fef..b5943317 100644 --- a/src/market_data/realtime/async.rs +++ b/src/market_data/realtime/async.rs @@ -7,7 +7,7 @@ use crate::messages::OutgoingMessages; use crate::messages::{IncomingMessages, Notice, ResponseMessage}; use crate::protocol::{check_version, Features}; #[cfg(not(feature = "sync"))] -use crate::subscriptions::ResponseContext; +use crate::subscriptions::DecoderContext; #[cfg(not(feature = "sync"))] use crate::subscriptions::StreamDecoder; use crate::subscriptions::Subscription; @@ -23,7 +23,7 @@ use crate::market_data::TradingHours; impl StreamDecoder for BidAsk { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::TickByTick]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::TickByTick => decoders::decode_bid_ask_tick(message), IncomingMessages::Error => Err(Error::from(message.clone())), @@ -34,7 +34,7 @@ impl StreamDecoder for BidAsk { fn cancel_message( _server_version: i32, request_id: Option, - _context: Option<&ResponseContext>, + _context: Option<&DecoderContext>, ) -> Result { let request_id = request_id.expect("Request ID required to encode cancel tick by tick"); encoders::encode_cancel_tick_by_tick(request_id) @@ -45,7 +45,7 @@ impl StreamDecoder for BidAsk { impl StreamDecoder for MidPoint { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::TickByTick]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::TickByTick => decoders::decode_mid_point_tick(message), IncomingMessages::Error => Err(Error::from(message.clone())), @@ -56,7 +56,7 @@ impl StreamDecoder for MidPoint { fn cancel_message( _server_version: i32, request_id: Option, - _context: Option<&ResponseContext>, + _context: Option<&DecoderContext>, ) -> Result { let request_id = request_id.expect("Request ID required to encode cancel tick by tick"); encoders::encode_cancel_tick_by_tick(request_id) @@ -67,14 +67,14 @@ impl StreamDecoder for MidPoint { impl StreamDecoder for Bar { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::RealTimeBars]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { decoders::decode_realtime_bar(message) } fn cancel_message( _server_version: i32, request_id: Option, - _context: Option<&ResponseContext>, + _context: Option<&DecoderContext>, ) -> Result { let request_id = request_id.expect("Request ID required to encode cancel realtime bars"); encoders::encode_cancel_realtime_bars(request_id) @@ -85,7 +85,7 @@ impl StreamDecoder for Bar { impl StreamDecoder for Trade { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::TickByTick]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::TickByTick => decoders::decode_trade_tick(message), IncomingMessages::Error => Err(Error::from(message.clone())), @@ -96,7 +96,7 @@ impl StreamDecoder for Trade { fn cancel_message( _server_version: i32, request_id: Option, - _context: Option<&ResponseContext>, + _context: Option<&DecoderContext>, ) -> Result { let request_id = request_id.expect("Request ID required to encode cancel tick by tick"); encoders::encode_cancel_tick_by_tick(request_id) @@ -108,11 +108,14 @@ impl StreamDecoder for MarketDepths { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::MarketDepth, IncomingMessages::MarketDepthL2, IncomingMessages::Error]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { use crate::messages; match message.message_type() { IncomingMessages::MarketDepth => Ok(MarketDepths::MarketDepth(decoders::decode_market_depth(message)?)), - IncomingMessages::MarketDepthL2 => Ok(MarketDepths::MarketDepthL2(decoders::decode_market_depth_l2(server_version, message)?)), + IncomingMessages::MarketDepthL2 => Ok(MarketDepths::MarketDepthL2(decoders::decode_market_depth_l2( + context.server_version, + message, + )?)), IncomingMessages::Error => { let code = message.peek_int(messages::CODE_INDEX).unwrap(); if (2100..2200).contains(&code) { @@ -128,7 +131,7 @@ impl StreamDecoder for MarketDepths { fn cancel_message( server_version: i32, request_id: Option, - context: Option<&ResponseContext>, + context: Option<&DecoderContext>, ) -> Result { let request_id = request_id.expect("Request ID required to encode cancel market depth"); encoders::encode_cancel_market_depth(server_version, request_id, context.map(|c| c.is_smart_depth).unwrap_or(false)) @@ -149,15 +152,15 @@ impl StreamDecoder for TickTypes { IncomingMessages::TickReqParams, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::TickPrice => Ok(decoders::decode_tick_price(server_version, message)?), + IncomingMessages::TickPrice => Ok(decoders::decode_tick_price(context.server_version, message)?), IncomingMessages::TickSize => Ok(TickTypes::Size(decoders::decode_tick_size(message)?)), IncomingMessages::TickString => Ok(TickTypes::String(decoders::decode_tick_string(message)?)), IncomingMessages::TickEFP => Ok(TickTypes::EFP(decoders::decode_tick_efp(message)?)), IncomingMessages::TickGeneric => Ok(TickTypes::Generic(decoders::decode_tick_generic(message)?)), IncomingMessages::TickOptionComputation => Ok(TickTypes::OptionComputation(decoders::decode_tick_option_computation( - server_version, + context.server_version, message, )?)), IncomingMessages::TickReqParams => Ok(TickTypes::RequestParameters(decoders::decode_tick_request_parameters(message)?)), diff --git a/src/market_data/realtime/mod.rs b/src/market_data/realtime/mod.rs index a3edf652..bf46db71 100644 --- a/src/market_data/realtime/mod.rs +++ b/src/market_data/realtime/mod.rs @@ -3,13 +3,13 @@ use time::OffsetDateTime; use crate::ToField; -#[cfg(feature = "sync")] -use crate::client::{ResponseContext, StreamDecoder}; use crate::contracts::OptionComputation; use crate::messages::Notice; #[cfg(feature = "sync")] use crate::messages::{self, IncomingMessages, RequestMessage, ResponseMessage}; #[cfg(feature = "sync")] +use crate::subscriptions::{DecoderContext, StreamDecoder}; +#[cfg(feature = "sync")] use crate::Error; // Common modules @@ -70,7 +70,7 @@ pub struct BidAsk { impl StreamDecoder for BidAsk { const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::TickByTick]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::TickByTick => common::decoders::decode_bid_ask_tick(message), IncomingMessages::Error => Err(Error::from(message.clone())), @@ -78,7 +78,7 @@ impl StreamDecoder for BidAsk { } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel realtime bars"); common::encoders::encode_cancel_tick_by_tick(request_id) } @@ -106,7 +106,7 @@ pub struct MidPoint { impl StreamDecoder for MidPoint { const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::TickByTick]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::TickByTick => common::decoders::decode_mid_point_tick(message), IncomingMessages::Error => Err(Error::from(message.clone())), @@ -114,7 +114,7 @@ impl StreamDecoder for MidPoint { } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel mid point ticks"); common::encoders::encode_cancel_tick_by_tick(request_id) } @@ -145,11 +145,11 @@ pub struct Bar { impl StreamDecoder for Bar { const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::RealTimeBars]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { common::decoders::decode_realtime_bar(message) } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel realtime bars"); common::encoders::encode_cancel_realtime_bars(request_id) } @@ -178,7 +178,7 @@ pub struct Trade { impl StreamDecoder for Trade { const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::TickByTick]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::TickByTick => common::decoders::decode_trade_tick(message), IncomingMessages::Error => Err(Error::from(message.clone())), @@ -186,7 +186,7 @@ impl StreamDecoder for Trade { } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel realtime bars"); common::encoders::encode_cancel_tick_by_tick(request_id) } @@ -280,11 +280,11 @@ pub struct MarketDepthL2 { impl StreamDecoder for MarketDepths { const RESPONSE_MESSAGE_IDS: &[IncomingMessages] = &[IncomingMessages::MarketDepth, IncomingMessages::MarketDepthL2, IncomingMessages::Error]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::MarketDepth => Ok(MarketDepths::MarketDepth(common::decoders::decode_market_depth(message)?)), IncomingMessages::MarketDepthL2 => Ok(MarketDepths::MarketDepthL2(common::decoders::decode_market_depth_l2( - server_version, + context.server_version, message, )?)), IncomingMessages::Error => { @@ -299,7 +299,7 @@ impl StreamDecoder for MarketDepths { } } - fn cancel_message(server_version: i32, request_id: Option, context: Option<&ResponseContext>) -> Result { + fn cancel_message(server_version: i32, request_id: Option, context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel realtime bars"); common::encoders::encode_cancel_market_depth(server_version, request_id, context.map(|c| c.is_smart_depth).unwrap_or(false)) } @@ -359,15 +359,15 @@ impl StreamDecoder for TickTypes { IncomingMessages::TickReqParams, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::TickPrice => Ok(common::decoders::decode_tick_price(server_version, message)?), + IncomingMessages::TickPrice => Ok(common::decoders::decode_tick_price(context.server_version, message)?), IncomingMessages::TickSize => Ok(TickTypes::Size(common::decoders::decode_tick_size(message)?)), IncomingMessages::TickString => Ok(TickTypes::String(common::decoders::decode_tick_string(message)?)), IncomingMessages::TickEFP => Ok(TickTypes::EFP(common::decoders::decode_tick_efp(message)?)), IncomingMessages::TickGeneric => Ok(TickTypes::Generic(common::decoders::decode_tick_generic(message)?)), IncomingMessages::TickOptionComputation => Ok(TickTypes::OptionComputation(common::decoders::decode_tick_option_computation( - server_version, + context.server_version, message, )?)), IncomingMessages::TickReqParams => Ok(TickTypes::RequestParameters(common::decoders::decode_tick_request_parameters(message)?)), @@ -377,7 +377,7 @@ impl StreamDecoder for TickTypes { } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel realtime bars"); common::encoders::encode_cancel_market_data(request_id) } diff --git a/src/market_data/realtime/sync.rs b/src/market_data/realtime/sync.rs index 96203c88..8b5ba527 100644 --- a/src/market_data/realtime/sync.rs +++ b/src/market_data/realtime/sync.rs @@ -1,7 +1,6 @@ use log::debug; use crate::client::blocking::{ClientRequestBuilders, Subscription}; -use crate::client::ResponseContext; use crate::contracts::Contract; use crate::messages::OutgoingMessages; use crate::orders::TagValue; @@ -125,13 +124,7 @@ pub(crate) fn market_depth( let builder = client.request(); let request = encoders::encode_request_market_depth(client.server_version(), builder.request_id(), contract, number_of_rows, is_smart_depth)?; - builder.send_with_context( - request, - ResponseContext { - is_smart_depth, - ..Default::default() - }, - ) + builder.send_with_context(request, client.decoder_context().with_smart_depth(is_smart_depth)) } /// Fetch the venues that provide market depth data for the connected account. diff --git a/src/news/async.rs b/src/news/async.rs index 18bca18b..55f3f75e 100644 --- a/src/news/async.rs +++ b/src/news/async.rs @@ -7,24 +7,23 @@ use crate::market_data::realtime; use crate::messages::OutgoingMessages; #[cfg(not(feature = "sync"))] use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; -use crate::subscriptions::ResponseContext; -#[cfg(not(feature = "sync"))] -use crate::subscriptions::StreamDecoder; use crate::subscriptions::Subscription; +#[cfg(not(feature = "sync"))] +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::{server_versions, Client, Error}; #[cfg(not(feature = "sync"))] impl StreamDecoder for NewsBulletin { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::NewsBulletins]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::NewsBulletins => Ok(decoders::decode_news_bulletin(message.clone())?), _ => Err(Error::UnexpectedResponse(message.clone())), } } - fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&DecoderContext>) -> Result { encoders::encode_cancel_news_bulletin() } } @@ -37,7 +36,7 @@ impl StreamDecoder for NewsArticle { IncomingMessages::TickNews, ]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::HistoricalNews => Ok(decoders::decode_historical_news(None, message.clone())?), IncomingMessages::HistoricalNewsEnd => Err(Error::EndOfStream), @@ -46,7 +45,7 @@ impl StreamDecoder for NewsArticle { } } - fn cancel_message(_server_version: i32, request_id: Option, context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, context: Option<&DecoderContext>) -> Result { // News articles can come from market data subscriptions, so use the appropriate cancel if context.and_then(|c| c.request_type) == Some(OutgoingMessages::RequestMarketData) { let request_id = request_id.expect("Request ID required to encode cancel market data"); @@ -79,12 +78,11 @@ pub(crate) async fn news_bulletins(client: &Client, all_messages: bool) -> Resul Ok(Subscription::new_from_internal::( internal_subscription, - client.server_version(), client.message_bus.clone(), None, None, Some(OutgoingMessages::RequestNewsBulletins), - Default::default(), + client.decoder_context(), )) } @@ -113,12 +111,11 @@ pub(crate) async fn historical_news( Ok(Subscription::new_from_internal::( internal_subscription, - client.server_version(), client.message_bus.clone(), Some(request_id), None, None, - Default::default(), + client.decoder_context(), )) } @@ -159,15 +156,11 @@ pub(crate) async fn contract_news(client: &Client, contract: &Contract, provider Ok(Subscription::new_from_internal::( internal_subscription, - client.server_version(), client.message_bus.clone(), Some(request_id), None, None, - ResponseContext { - request_type: Some(OutgoingMessages::RequestMarketData), - ..Default::default() - }, + client.decoder_context().with_request_type(OutgoingMessages::RequestMarketData), )) } @@ -183,15 +176,11 @@ pub(crate) async fn broad_tape_news(client: &Client, provider_code: &str) -> Res Ok(Subscription::new_from_internal::( internal_subscription, - client.server_version(), client.message_bus.clone(), Some(request_id), None, None, - ResponseContext { - request_type: Some(OutgoingMessages::RequestMarketData), - ..Default::default() - }, + client.decoder_context().with_request_type(OutgoingMessages::RequestMarketData), )) } diff --git a/src/news/sync.rs b/src/news/sync.rs index 0315b3e7..b18f05f1 100644 --- a/src/news/sync.rs +++ b/src/news/sync.rs @@ -5,23 +5,23 @@ use std::sync::Arc; use super::common::{decoders, encoders}; use super::*; use crate::client::blocking::{SharesChannel, Subscription}; -use crate::client::{ResponseContext, StreamDecoder}; use crate::contracts::Contract; use crate::market_data::realtime; use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::{client::sync::Client, server_versions, Error}; impl SharesChannel for Vec {} impl StreamDecoder for NewsBulletin { - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::NewsBulletins => Ok(decoders::decode_news_bulletin(message.clone())?), _ => Err(Error::UnexpectedResponse(message.clone())), } } - fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&DecoderContext>) -> Result { encoders::encode_cancel_news_bulletin() } } @@ -29,7 +29,7 @@ impl StreamDecoder for NewsBulletin { impl SharesChannel for Subscription {} impl StreamDecoder for NewsArticle { - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::HistoricalNews => Ok(decoders::decode_historical_news(None, message.clone())?), IncomingMessages::HistoricalNewsEnd => Err(Error::EndOfStream), @@ -38,7 +38,7 @@ impl StreamDecoder for NewsArticle { } } - fn cancel_message(_server_version: i32, request_id: Option, context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, context: Option<&DecoderContext>) -> Result { if context.and_then(|ctx| ctx.request_type) == Some(OutgoingMessages::RequestMarketData) { let request_id = request_id.ok_or_else(|| Error::InvalidArgument("request id required to cancel market data subscription".to_string()))?; @@ -69,12 +69,7 @@ pub(crate) fn news_bulletins(client: &Client, all_messages: bool) -> Result Result for PlaceOrder { IncomingMessages::Error, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OpenOrder => Ok(PlaceOrder::OpenOrder(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(PlaceOrder::OrderStatus(decoders::decode_order_status(server_version, message)?)), - IncomingMessages::ExecutionData => Ok(PlaceOrder::ExecutionData(decoders::decode_execution_data(server_version, message)?)), - IncomingMessages::CommissionsReport => Ok(PlaceOrder::CommissionReport(decoders::decode_commission_report(server_version, message)?)), + IncomingMessages::OpenOrder => Ok(PlaceOrder::OpenOrder(decoders::decode_open_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::OrderStatus => Ok(PlaceOrder::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), + IncomingMessages::ExecutionData => Ok(PlaceOrder::ExecutionData(decoders::decode_execution_data( + context.server_version, + message, + )?)), + IncomingMessages::CommissionsReport => Ok(PlaceOrder::CommissionReport(decoders::decode_commission_report( + context.server_version, + message, + )?)), IncomingMessages::Error => Ok(PlaceOrder::Message(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), } @@ -49,13 +58,19 @@ impl StreamDecoder for OrderUpdate { IncomingMessages::Error, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OpenOrder => Ok(OrderUpdate::OpenOrder(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(OrderUpdate::OrderStatus(decoders::decode_order_status(server_version, message)?)), - IncomingMessages::ExecutionData => Ok(OrderUpdate::ExecutionData(decoders::decode_execution_data(server_version, message)?)), + IncomingMessages::OpenOrder => Ok(OrderUpdate::OpenOrder(decoders::decode_open_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::OrderStatus => Ok(OrderUpdate::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), + IncomingMessages::ExecutionData => Ok(OrderUpdate::ExecutionData(decoders::decode_execution_data( + context.server_version, + message, + )?)), IncomingMessages::CommissionsReport => Ok(OrderUpdate::CommissionReport(decoders::decode_commission_report( - server_version, + context.server_version, message, )?)), IncomingMessages::Error => Ok(OrderUpdate::Message(Notice::from(message))), @@ -68,9 +83,9 @@ impl StreamDecoder for OrderUpdate { impl StreamDecoder for CancelOrder { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::OrderStatus, IncomingMessages::Error]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OrderStatus => Ok(CancelOrder::OrderStatus(decoders::decode_order_status(server_version, message)?)), + IncomingMessages::OrderStatus => Ok(CancelOrder::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), IncomingMessages::Error => Ok(CancelOrder::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), } @@ -89,12 +104,15 @@ impl StreamDecoder for Orders { IncomingMessages::Error, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::CompletedOrder => Ok(Orders::OrderData(decoders::decode_completed_order(server_version, message.clone())?)), - IncomingMessages::CommissionsReport => Ok(Orders::OrderData(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OpenOrder => Ok(Orders::OrderData(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(Orders::OrderStatus(decoders::decode_order_status(server_version, message)?)), + IncomingMessages::CompletedOrder => Ok(Orders::OrderData(decoders::decode_completed_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::CommissionsReport => Ok(Orders::OrderData(decoders::decode_open_order(context.server_version, message.clone())?)), + IncomingMessages::OpenOrder => Ok(Orders::OrderData(decoders::decode_open_order(context.server_version, message.clone())?)), + IncomingMessages::OrderStatus => Ok(Orders::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), IncomingMessages::OpenOrderEnd | IncomingMessages::CompletedOrdersEnd => Err(Error::EndOfStream), IncomingMessages::Error => Ok(Orders::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), @@ -111,10 +129,16 @@ impl StreamDecoder for Executions { IncomingMessages::Error, ]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::ExecutionData => Ok(Executions::ExecutionData(decoders::decode_execution_data(server_version, message)?)), - IncomingMessages::CommissionsReport => Ok(Executions::CommissionReport(decoders::decode_commission_report(server_version, message)?)), + IncomingMessages::ExecutionData => Ok(Executions::ExecutionData(decoders::decode_execution_data( + context.server_version, + message, + )?)), + IncomingMessages::CommissionsReport => Ok(Executions::CommissionReport(decoders::decode_commission_report( + context.server_version, + message, + )?)), IncomingMessages::ExecutionDataEnd => Err(Error::EndOfStream), IncomingMessages::Error => Ok(Executions::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), @@ -126,10 +150,16 @@ impl StreamDecoder for Executions { impl StreamDecoder for ExerciseOptions { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::OpenOrder, IncomingMessages::OrderStatus, IncomingMessages::Error]; - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OpenOrder => Ok(ExerciseOptions::OpenOrder(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(ExerciseOptions::OrderStatus(decoders::decode_order_status(server_version, message)?)), + IncomingMessages::OpenOrder => Ok(ExerciseOptions::OpenOrder(decoders::decode_open_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::OrderStatus => Ok(ExerciseOptions::OrderStatus(decoders::decode_order_status( + context.server_version, + message, + )?)), IncomingMessages::Error => Ok(ExerciseOptions::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), } @@ -143,7 +173,7 @@ pub(crate) async fn order_update_stream(client: &Client) -> Result( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -187,7 +217,7 @@ pub(crate) async fn place_order(client: &Client, order_id: i32, contract: &Contr Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -203,7 +233,7 @@ pub(crate) async fn cancel_order(client: &Client, order_id: i32, manual_order_ca Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -247,7 +277,7 @@ pub(crate) async fn completed_orders(client: &Client, api_only: bool) -> Result< let internal_subscription = client.send_shared_request(OutgoingMessages::RequestCompletedOrders, request).await?; Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -263,7 +293,7 @@ pub(crate) async fn open_orders(client: &Client) -> Result, let internal_subscription = client.send_shared_request(OutgoingMessages::RequestOpenOrders, request).await?; Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -276,7 +306,7 @@ pub(crate) async fn all_open_orders(client: &Client) -> Result( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -288,7 +318,7 @@ pub(crate) async fn auto_open_orders(client: &Client, auto_bind: bool) -> Result let internal_subscription = client.send_shared_request(OutgoingMessages::RequestAutoOpenOrders, request).await?; Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -307,7 +337,7 @@ pub(crate) async fn executions(client: &Client, filter: ExecutionFilter) -> Resu let internal_subscription = client.send_request(request_id, request).await?; Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } @@ -336,7 +366,7 @@ pub(crate) async fn exercise_options( let internal_subscription = client.send_order(order_id, request).await?; Ok(Subscription::new_from_internal_simple::( internal_subscription, - client.server_version(), + client.decoder_context(), client.message_bus.clone(), )) } diff --git a/src/orders/sync.rs b/src/orders/sync.rs index 6bc5dde5..b42f0aa4 100644 --- a/src/orders/sync.rs +++ b/src/orders/sync.rs @@ -3,19 +3,28 @@ use std::sync::Arc; use super::common::{decoders, encoders, verify}; use super::{CancelOrder, ExecutionFilter, Executions, ExerciseAction, ExerciseOptions, OrderUpdate, Orders, PlaceOrder}; use crate::client::blocking::Subscription; -use crate::client::StreamDecoder; use crate::contracts::Contract; use crate::messages::{IncomingMessages, Notice, OutgoingMessages, ResponseMessage}; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::{client::sync::Client, server_versions, Error}; use time::OffsetDateTime; impl StreamDecoder for PlaceOrder { - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OpenOrder => Ok(PlaceOrder::OpenOrder(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(PlaceOrder::OrderStatus(decoders::decode_order_status(server_version, message)?)), - IncomingMessages::ExecutionData => Ok(PlaceOrder::ExecutionData(decoders::decode_execution_data(server_version, message)?)), - IncomingMessages::CommissionsReport => Ok(PlaceOrder::CommissionReport(decoders::decode_commission_report(server_version, message)?)), + IncomingMessages::OpenOrder => Ok(PlaceOrder::OpenOrder(decoders::decode_open_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::OrderStatus => Ok(PlaceOrder::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), + IncomingMessages::ExecutionData => Ok(PlaceOrder::ExecutionData(decoders::decode_execution_data( + context.server_version, + message, + )?)), + IncomingMessages::CommissionsReport => Ok(PlaceOrder::CommissionReport(decoders::decode_commission_report( + context.server_version, + message, + )?)), IncomingMessages::Error => Ok(PlaceOrder::Message(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), } @@ -23,13 +32,19 @@ impl StreamDecoder for PlaceOrder { } impl StreamDecoder for OrderUpdate { - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OpenOrder => Ok(OrderUpdate::OpenOrder(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(OrderUpdate::OrderStatus(decoders::decode_order_status(server_version, message)?)), - IncomingMessages::ExecutionData => Ok(OrderUpdate::ExecutionData(decoders::decode_execution_data(server_version, message)?)), + IncomingMessages::OpenOrder => Ok(OrderUpdate::OpenOrder(decoders::decode_open_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::OrderStatus => Ok(OrderUpdate::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), + IncomingMessages::ExecutionData => Ok(OrderUpdate::ExecutionData(decoders::decode_execution_data( + context.server_version, + message, + )?)), IncomingMessages::CommissionsReport => Ok(OrderUpdate::CommissionReport(decoders::decode_commission_report( - server_version, + context.server_version, message, )?)), IncomingMessages::Error => Ok(OrderUpdate::Message(Notice::from(message))), @@ -39,9 +54,9 @@ impl StreamDecoder for OrderUpdate { } impl StreamDecoder for CancelOrder { - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OrderStatus => Ok(CancelOrder::OrderStatus(decoders::decode_order_status(server_version, message)?)), + IncomingMessages::OrderStatus => Ok(CancelOrder::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), IncomingMessages::Error => Ok(CancelOrder::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), } @@ -49,12 +64,15 @@ impl StreamDecoder for CancelOrder { } impl StreamDecoder for Orders { - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::CompletedOrder => Ok(Orders::OrderData(decoders::decode_completed_order(server_version, message.clone())?)), - IncomingMessages::CommissionsReport => Ok(Orders::OrderData(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OpenOrder => Ok(Orders::OrderData(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(Orders::OrderStatus(decoders::decode_order_status(server_version, message)?)), + IncomingMessages::CompletedOrder => Ok(Orders::OrderData(decoders::decode_completed_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::CommissionsReport => Ok(Orders::OrderData(decoders::decode_open_order(context.server_version, message.clone())?)), + IncomingMessages::OpenOrder => Ok(Orders::OrderData(decoders::decode_open_order(context.server_version, message.clone())?)), + IncomingMessages::OrderStatus => Ok(Orders::OrderStatus(decoders::decode_order_status(context.server_version, message)?)), IncomingMessages::OpenOrderEnd | IncomingMessages::CompletedOrdersEnd => Err(Error::EndOfStream), IncomingMessages::Error => Ok(Orders::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), @@ -63,10 +81,16 @@ impl StreamDecoder for Orders { } impl StreamDecoder for Executions { - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::ExecutionData => Ok(Executions::ExecutionData(decoders::decode_execution_data(server_version, message)?)), - IncomingMessages::CommissionsReport => Ok(Executions::CommissionReport(decoders::decode_commission_report(server_version, message)?)), + IncomingMessages::ExecutionData => Ok(Executions::ExecutionData(decoders::decode_execution_data( + context.server_version, + message, + )?)), + IncomingMessages::CommissionsReport => Ok(Executions::CommissionReport(decoders::decode_commission_report( + context.server_version, + message, + )?)), IncomingMessages::ExecutionDataEnd => Err(Error::EndOfStream), IncomingMessages::Error => Ok(Executions::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), @@ -75,10 +99,16 @@ impl StreamDecoder for Executions { } impl StreamDecoder for ExerciseOptions { - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { - IncomingMessages::OpenOrder => Ok(ExerciseOptions::OpenOrder(decoders::decode_open_order(server_version, message.clone())?)), - IncomingMessages::OrderStatus => Ok(ExerciseOptions::OrderStatus(decoders::decode_order_status(server_version, message)?)), + IncomingMessages::OpenOrder => Ok(ExerciseOptions::OpenOrder(decoders::decode_open_order( + context.server_version, + message.clone(), + )?)), + IncomingMessages::OrderStatus => Ok(ExerciseOptions::OrderStatus(decoders::decode_order_status( + context.server_version, + message, + )?)), IncomingMessages::Error => Ok(ExerciseOptions::Notice(Notice::from(message))), _ => Err(Error::UnexpectedResponse(message.clone())), } @@ -90,12 +120,7 @@ impl StreamDecoder for ExerciseOptions { /// This function returns a subscription that will receive updates of activity for all orders placed by the client. pub(crate) fn order_update_stream(client: &Client) -> Result, Error> { let subscription = client.create_order_update_subscription()?; - Ok(Subscription::new( - client.server_version, - Arc::clone(&client.message_bus), - subscription, - None, - )) + Ok(Subscription::new(Arc::clone(&client.message_bus), subscription, client.decoder_context())) } /// Submits an Order. @@ -148,12 +173,7 @@ pub(crate) fn place_order(client: &Client, order_id: i32, contract: &Contract, o let request = encoders::encode_place_order(client.server_version, order_id, contract, order)?; let subscription = client.send_order(order_id, request)?; - Ok(Subscription::new( - client.server_version, - Arc::clone(&client.message_bus), - subscription, - None, - )) + Ok(Subscription::new(Arc::clone(&client.message_bus), subscription, client.decoder_context())) } /// Cancels an open Order and returns a subscription to receive cancellation events. @@ -177,12 +197,7 @@ pub(crate) fn cancel_order(client: &Client, order_id: i32, manual_order_cancel_t let request = encoders::encode_cancel_order(client.server_version, order_id, manual_order_cancel_time)?; let subscription = client.send_order(order_id, request)?; - Ok(Subscription::new( - client.server_version, - Arc::clone(&client.message_bus), - subscription, - None, - )) + Ok(Subscription::new(Arc::clone(&client.message_bus), subscription, client.decoder_context())) } /// Cancels all open Orders. @@ -242,12 +257,7 @@ pub(crate) fn completed_orders(client: &Client, api_only: bool) -> Result Result, Error let request = encoders::encode_open_orders()?; let subscription = client.send_shared_request(OutgoingMessages::RequestOpenOrders, request)?; - Ok(Subscription::new( - client.server_version, - Arc::clone(&client.message_bus), - subscription, - None, - )) + Ok(Subscription::new(Arc::clone(&client.message_bus), subscription, client.decoder_context())) } /// Requests all current open orders in associated accounts at the current moment. @@ -282,12 +287,7 @@ pub(crate) fn all_open_orders(client: &Client) -> Result, E let request = encoders::encode_all_open_orders()?; let subscription = client.send_shared_request(OutgoingMessages::RequestAllOpenOrders, request)?; - Ok(Subscription::new( - client.server_version, - Arc::clone(&client.message_bus), - subscription, - None, - )) + Ok(Subscription::new(Arc::clone(&client.message_bus), subscription, client.decoder_context())) } /// Requests status updates about future orders placed from TWS. @@ -305,12 +305,7 @@ pub(crate) fn auto_open_orders(client: &Client, auto_bind: bool) -> Result Result> for Vec { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::ScannerData]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result, Error> { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result, Error> { decoders::decode_scanner_message(message) } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel scanner subscription."); encoders::encode_cancel_scanner_subscription(request_id) } @@ -56,12 +56,11 @@ pub(crate) async fn scanner_subscription( Ok(Subscription::new_from_internal::>( internal_subscription, - client.server_version(), client.message_bus.clone(), Some(request_id), None, None, - Default::default(), + client.decoder_context(), )) } diff --git a/src/scanner/sync.rs b/src/scanner/sync.rs index f7313145..b73655bb 100644 --- a/src/scanner/sync.rs +++ b/src/scanner/sync.rs @@ -5,17 +5,17 @@ use std::sync::Arc; use super::common::{decoders, encoders}; use super::*; use crate::client::blocking::Subscription; -use crate::client::{ResponseContext, StreamDecoder}; use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage}; use crate::orders::TagValue; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::{client::sync::Client, server_versions, Error}; impl StreamDecoder> for Vec { - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result, Error> { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result, Error> { decoders::decode_scanner_message(message) } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = request_id.expect("Request ID required to encode cancel scanner subscription."); encoders::encode_cancel_scanner_subscription(request_id) } @@ -50,12 +50,7 @@ pub(crate) fn scanner_subscription( let request = encoders::encode_scanner_subscription(request_id, client.server_version, subscription, filter)?; let subscription = client.send_request(request_id, request)?; - Ok(Subscription::new( - client.server_version, - Arc::clone(&client.message_bus), - subscription, - None, - )) + Ok(Subscription::new(Arc::clone(&client.message_bus), subscription, client.decoder_context())) } #[cfg(test)] diff --git a/src/subscriptions/async.rs b/src/subscriptions/async.rs index 4fb6810b..f3a49125 100644 --- a/src/subscriptions/async.rs +++ b/src/subscriptions/async.rs @@ -6,15 +6,15 @@ use std::sync::Arc; use log::{debug, warn}; use tokio::sync::mpsc; -use super::common::{check_retry, process_decode_result, ProcessingResult, RetryDecision}; -use super::{ResponseContext, StreamDecoder}; +use super::common::{check_retry, process_decode_result, DecoderContext, ProcessingResult, RetryDecision}; +use super::StreamDecoder; use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage}; use crate::transport::{AsyncInternalSubscription, AsyncMessageBus}; use crate::Error; // Type aliases to reduce complexity -type CancelFn = Box, Option<&ResponseContext>) -> Result + Send + Sync>; -type DecoderFn = Arc Result + Send + Sync>; +type CancelFn = Box, Option<&DecoderContext>) -> Result + Send + Sync>; +type DecoderFn = Arc Result + Send + Sync>; /// Asynchronous subscription for streaming data pub struct Subscription { @@ -23,9 +23,8 @@ pub struct Subscription { request_id: Option, order_id: Option, _message_type: Option, - response_context: ResponseContext, + context: DecoderContext, cancelled: Arc, - server_version: i32, message_bus: Option>, /// Cancel message generator cancel_fn: Option>, @@ -36,7 +35,7 @@ enum SubscriptionInner { WithDecoder { subscription: AsyncInternalSubscription, decoder: DecoderFn, - server_version: i32, + context: DecoderContext, }, /// Pre-decoded subscription - receives T directly PreDecoded { receiver: mpsc::UnboundedReceiver> }, @@ -48,11 +47,11 @@ impl Clone for SubscriptionInner { SubscriptionInner::WithDecoder { subscription, decoder, - server_version, + context, } => SubscriptionInner::WithDecoder { subscription: subscription.clone(), decoder: decoder.clone(), - server_version: *server_version, + context: context.clone(), }, SubscriptionInner::PreDecoded { .. } => { // Can't clone mpsc receivers @@ -69,9 +68,8 @@ impl Clone for Subscription { request_id: self.request_id, order_id: self.order_id, _message_type: self._message_type, - response_context: self.response_context.clone(), + context: self.context.clone(), cancelled: self.cancelled.clone(), - server_version: self.server_version, message_bus: self.message_bus.clone(), cancel_fn: self.cancel_fn.clone(), } @@ -83,29 +81,27 @@ impl Subscription { #[allow(clippy::too_many_arguments)] pub fn with_decoder( internal: AsyncInternalSubscription, - server_version: i32, message_bus: Arc, decoder: D, request_id: Option, order_id: Option, message_type: Option, - response_context: ResponseContext, + context: DecoderContext, ) -> Self where - D: Fn(i32, &mut ResponseMessage) -> Result + Send + Sync + 'static, + D: Fn(&DecoderContext, &mut ResponseMessage) -> Result + Send + Sync + 'static, { Self { inner: SubscriptionInner::WithDecoder { subscription: internal, decoder: Arc::new(decoder), - server_version, + context: context.clone(), }, request_id, order_id, _message_type: message_type, - response_context, + context, cancelled: Arc::new(AtomicBool::new(false)), - server_version, message_bus: Some(message_bus), cancel_fn: None, } @@ -115,93 +111,67 @@ impl Subscription { #[allow(clippy::too_many_arguments)] pub fn new_with_decoder( internal: AsyncInternalSubscription, - server_version: i32, message_bus: Arc, decoder: F, request_id: Option, order_id: Option, message_type: Option, - response_context: ResponseContext, + context: DecoderContext, ) -> Self where - F: Fn(i32, &mut ResponseMessage) -> Result + Send + Sync + 'static, + F: Fn(&DecoderContext, &mut ResponseMessage) -> Result + Send + Sync + 'static, { - Self::with_decoder( - internal, - server_version, - message_bus, - decoder, - request_id, - order_id, - message_type, - response_context, - ) + Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context) } /// Create a subscription from components and a decoder (alias for with_decoder) #[allow(clippy::too_many_arguments)] pub fn with_decoder_components( internal: AsyncInternalSubscription, - server_version: i32, message_bus: Arc, decoder: D, request_id: Option, order_id: Option, message_type: Option, - response_context: ResponseContext, + context: DecoderContext, ) -> Self where - D: Fn(i32, &mut ResponseMessage) -> Result + Send + Sync + 'static, + D: Fn(&DecoderContext, &mut ResponseMessage) -> Result + Send + Sync + 'static, { - Self::with_decoder( - internal, - server_version, - message_bus, - decoder, - request_id, - order_id, - message_type, - response_context, - ) + Self::with_decoder(internal, message_bus, decoder, request_id, order_id, message_type, context) } /// Create a subscription from an internal subscription using the DataStream decoder pub(crate) fn new_from_internal( internal: AsyncInternalSubscription, - server_version: i32, message_bus: Arc, request_id: Option, order_id: Option, message_type: Option, - response_context: ResponseContext, + context: DecoderContext, ) -> Self where D: StreamDecoder + 'static, T: 'static, { - let mut sub = Self::with_decoder_components( - internal, - server_version, - message_bus, - D::decode, - request_id, - order_id, - message_type, - response_context, - ); + let mut sub = Self::with_decoder_components(internal, message_bus, D::decode, request_id, order_id, message_type, context); // Store the cancel function sub.cancel_fn = Some(Arc::new(Box::new(D::cancel_message))); sub } /// Create a subscription from internal subscription without explicit metadata - pub(crate) fn new_from_internal_simple(internal: AsyncInternalSubscription, server_version: i32, message_bus: Arc) -> Self + pub(crate) fn new_from_internal_simple( + internal: AsyncInternalSubscription, + context: DecoderContext, + message_bus: Arc, + ) -> Self where D: StreamDecoder + 'static, T: 'static, { // The AsyncInternalSubscription already has cleanup logic, so we don't need cancel metadata - Self::new_from_internal::(internal, server_version, message_bus, None, None, None, ResponseContext::default()) + Self::new_from_internal::(internal, message_bus, None, None, None, context) } /// Create subscription from existing receiver (for backward compatibility) @@ -213,9 +183,8 @@ impl Subscription { request_id: None, order_id: None, _message_type: None, - response_context: ResponseContext::default(), + context: DecoderContext::default(), cancelled: Arc::new(AtomicBool::new(false)), - server_version: 0, // Default value for backward compatibility message_bus: None, cancel_fn: None, } @@ -230,13 +199,13 @@ impl Subscription { SubscriptionInner::WithDecoder { subscription, decoder, - server_version, + context, } => { let mut retry_count = 0; loop { match subscription.next().await { Some(Ok(mut message)) => { - let result = decoder(*server_version, &mut message); + let result = decoder(context, &mut message); match process_decode_result(result) { ProcessingResult::Success(val) => return Some(Ok(val)), ProcessingResult::EndOfStream => return None, @@ -276,7 +245,7 @@ impl Subscription { if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) { let id = self.request_id.or(self.order_id); - if let Ok(message) = cancel_fn(self.server_version, id, Some(&self.response_context)) { + if let Ok(message) = cancel_fn(self.context.server_version, id, Some(&self.context)) { if let Err(e) = message_bus.send_message(message).await { warn!("error sending cancel message: {e}") } @@ -302,11 +271,10 @@ impl Drop for Subscription { if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) { let message_bus = message_bus.clone(); let id = self.request_id.or(self.order_id); - let response_context = self.response_context.clone(); - let server_version = self.server_version; + let context = self.context.clone(); // Clone the cancel function for use in the spawned task - if let Ok(message) = cancel_fn(server_version, id, Some(&response_context)) { + if let Ok(message) = cancel_fn(context.server_version, id, Some(&context)) { // Spawn a task to send the cancel message since drop can't be async tokio::spawn(async move { if let Err(e) = message_bus.send_message(message).await { @@ -346,9 +314,8 @@ mod tests { let subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus, - |_server_version, _msg| { + |_context, _msg| { let bar = Bar { date: OffsetDateTime::now_utc(), open: 100.5, @@ -364,7 +331,7 @@ mod tests { Some(9000), None, Some(OutgoingMessages::RequestRealTimeBars), - ResponseContext::default(), + DecoderContext::default(), ); // Send a test message @@ -388,13 +355,12 @@ mod tests { let subscription: Subscription = Subscription::new_with_decoder( internal, - 176, message_bus, - |_version, _msg| Ok("decoded".to_string()), + |_context, _msg| Ok("decoded".to_string()), Some(1), None, Some(OutgoingMessages::RequestMarketData), - ResponseContext::default(), + DecoderContext::default(), ); assert_eq!(subscription.request_id, Some(1)); @@ -409,13 +375,12 @@ mod tests { let subscription: Subscription = Subscription::with_decoder_components( internal, - 176, message_bus, - |_version, _msg| Ok(42), + |_context, _msg| Ok(42), Some(100), Some(200), Some(OutgoingMessages::RequestPositions), - ResponseContext::default(), + DecoderContext::default(), ); assert_eq!(subscription.request_id, Some(100)); @@ -444,13 +409,12 @@ mod tests { let mut subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus, - |_version, _msg| Err(Error::Simple("decode error".into())), + |_context, _msg| Err(Error::Simple("decode error".into())), None, None, None, - ResponseContext::default(), + DecoderContext::default(), ); // Send a message that will trigger the error @@ -470,13 +434,12 @@ mod tests { let mut subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus, - |_version, _msg| Err(Error::EndOfStream), + |_context, _msg| Err(Error::EndOfStream), None, None, None, - ResponseContext::default(), + DecoderContext::default(), ); // Send a message that will trigger end of stream @@ -502,13 +465,12 @@ mod tests { let mut subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus.clone(), - |_version, _msg| Ok("test".to_string()), + |_context, _msg| Ok("test".to_string()), Some(123), None, Some(OutgoingMessages::RequestMarketData), - ResponseContext::default(), + DecoderContext::default(), ); subscription.cancel_fn = Some(Arc::new(cancel_fn)); @@ -530,23 +492,21 @@ mod tests { let subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus, - |_version, _msg| Ok("test".to_string()), + |_context, _msg| Ok("test".to_string()), Some(456), Some(789), Some(OutgoingMessages::RequestPositions), - ResponseContext { - is_smart_depth: true, - request_type: Some(OutgoingMessages::RequestPositions), - }, + DecoderContext::default() + .with_smart_depth(true) + .with_request_type(OutgoingMessages::RequestPositions), ); let cloned = subscription.clone(); assert_eq!(cloned.request_id, Some(456)); assert_eq!(cloned.order_id, Some(789)); assert_eq!(cloned._message_type, Some(OutgoingMessages::RequestPositions)); - assert!(cloned.response_context.is_smart_depth); + assert!(cloned.context.is_smart_depth); } #[tokio::test] @@ -565,13 +525,12 @@ mod tests { { let mut subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus.clone(), - |_version, _msg| Ok("test".to_string()), + |_context, _msg| Ok("test".to_string()), Some(999), None, Some(OutgoingMessages::RequestMarketData), - ResponseContext::default(), + DecoderContext::default(), ); subscription.cancel_fn = Some(Arc::new(cancel_fn)); // Subscription will be dropped here and should send cancel message @@ -597,23 +556,21 @@ mod tests { let (_tx, rx) = broadcast::channel(100); let internal = AsyncInternalSubscription::new(rx); - let context = ResponseContext { - is_smart_depth: true, - request_type: Some(OutgoingMessages::RequestMarketDepth), - }; + let context = DecoderContext::default() + .with_smart_depth(true) + .with_request_type(OutgoingMessages::RequestMarketDepth); let subscription: Subscription = Subscription::with_decoder( internal, - 176, message_bus, - |_version, _msg| Ok("test".to_string()), + |_context, _msg| Ok("test".to_string()), None, None, None, context.clone(), ); - assert_eq!(subscription.response_context, context); + assert_eq!(subscription.context, context); } #[tokio::test] @@ -622,11 +579,11 @@ mod tests { struct TestDecoder; impl StreamDecoder for TestDecoder { - fn decode(_server_version: i32, _msg: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, _msg: &mut ResponseMessage) -> Result { Ok("decoded".to_string()) } - fn cancel_message(_server_version: i32, _id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, _id: Option, _context: Option<&DecoderContext>) -> Result { let mut msg = RequestMessage::new(); msg.push_field(&OutgoingMessages::CancelMarketData); Ok(msg) @@ -637,7 +594,8 @@ mod tests { let (_tx, rx) = broadcast::channel(100); let internal = AsyncInternalSubscription::new(rx); - let subscription: Subscription = Subscription::new_from_internal_simple::(internal, 176, message_bus); + let subscription: Subscription = + Subscription::new_from_internal_simple::(internal, DecoderContext::default(), message_bus); assert!(subscription.cancel_fn.is_some()); } diff --git a/src/subscriptions/common.rs b/src/subscriptions/common.rs index 6f707495..775736ed 100644 --- a/src/subscriptions/common.rs +++ b/src/subscriptions/common.rs @@ -1,5 +1,7 @@ //! Common utilities for subscription processing +use time_tz::Tz; + use crate::errors::Error; use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage}; @@ -139,48 +141,77 @@ mod tests { } #[test] - fn test_response_context_default() { - let context = ResponseContext::default(); + fn test_decoder_context_default() { + let context = DecoderContext::default(); + assert_eq!(context.server_version, 0); + assert!(context.time_zone.is_none()); + assert!(context.request_type.is_none()); assert!(!context.is_smart_depth); + } + + #[test] + fn test_decoder_context_new() { + let context = DecoderContext::new(176, None); + assert_eq!(context.server_version, 176); + assert!(context.time_zone.is_none()); assert!(context.request_type.is_none()); + assert!(!context.is_smart_depth); } #[test] - fn test_response_context_clone() { - let context = ResponseContext { + fn test_decoder_context_builder() { + let context = DecoderContext::new(176, None) + .with_request_type(crate::messages::OutgoingMessages::RequestMarketData) + .with_smart_depth(true); + + assert_eq!(context.server_version, 176); + assert_eq!(context.request_type, Some(crate::messages::OutgoingMessages::RequestMarketData)); + assert!(context.is_smart_depth); + } + + #[test] + fn test_decoder_context_clone() { + let context = DecoderContext { + server_version: 176, + time_zone: None, is_smart_depth: true, request_type: Some(crate::messages::OutgoingMessages::RequestMarketData), }; let cloned = context.clone(); assert_eq!(context, cloned); - assert_eq!(cloned.is_smart_depth, true); + assert_eq!(cloned.server_version, 176); + assert!(cloned.is_smart_depth); assert_eq!(cloned.request_type, Some(crate::messages::OutgoingMessages::RequestMarketData)); } #[test] - fn test_response_context_equality() { + fn test_decoder_context_equality() { struct TestCase { name: &'static str, - context1: ResponseContext, - context2: ResponseContext, + context1: DecoderContext, + context2: DecoderContext, expected: bool, } let test_cases = vec![ TestCase { name: "default_contexts_equal", - context1: ResponseContext::default(), - context2: ResponseContext::default(), + context1: DecoderContext::default(), + context2: DecoderContext::default(), expected: true, }, TestCase { name: "same_values_equal", - context1: ResponseContext { + context1: DecoderContext { + server_version: 176, + time_zone: None, is_smart_depth: true, request_type: Some(crate::messages::OutgoingMessages::RequestMarketData), }, - context2: ResponseContext { + context2: DecoderContext { + server_version: 176, + time_zone: None, is_smart_depth: true, request_type: Some(crate::messages::OutgoingMessages::RequestMarketData), }, @@ -188,37 +219,37 @@ mod tests { }, TestCase { name: "different_smart_depth", - context1: ResponseContext { + context1: DecoderContext { is_smart_depth: true, - request_type: None, + ..Default::default() }, - context2: ResponseContext { + context2: DecoderContext { is_smart_depth: false, - request_type: None, + ..Default::default() }, expected: false, }, TestCase { name: "different_request_type", - context1: ResponseContext { - is_smart_depth: false, + context1: DecoderContext { request_type: Some(crate::messages::OutgoingMessages::RequestMarketData), + ..Default::default() }, - context2: ResponseContext { - is_smart_depth: false, + context2: DecoderContext { request_type: Some(crate::messages::OutgoingMessages::CancelMarketData), + ..Default::default() }, expected: false, }, TestCase { - name: "one_none_one_some", - context1: ResponseContext { - is_smart_depth: false, - request_type: None, + name: "different_server_version", + context1: DecoderContext { + server_version: 175, + ..Default::default() }, - context2: ResponseContext { - is_smart_depth: false, - request_type: Some(crate::messages::OutgoingMessages::RequestMarketData), + context2: DecoderContext { + server_version: 176, + ..Default::default() }, expected: false, }, @@ -230,14 +261,17 @@ mod tests { } #[test] - fn test_response_context_debug_format() { - let context = ResponseContext { + fn test_decoder_context_debug_format() { + let context = DecoderContext { + server_version: 176, + time_zone: None, is_smart_depth: true, request_type: Some(crate::messages::OutgoingMessages::RequestMarketData), }; let debug_str = format!("{:?}", context); - assert!(debug_str.contains("ResponseContext")); + assert!(debug_str.contains("DecoderContext")); + assert!(debug_str.contains("server_version")); assert!(debug_str.contains("is_smart_depth")); assert!(debug_str.contains("true")); assert!(debug_str.contains("request_type")); @@ -245,30 +279,59 @@ mod tests { } } -/// Context information for response handling +/// Context for decoding responses, providing all necessary state for decoders. #[derive(Debug, Clone, Default, PartialEq)] -pub struct ResponseContext { +pub struct DecoderContext { + /// Server version for protocol compatibility + pub server_version: i32, + /// Timezone for parsing timestamps (from TWS connection) + pub time_zone: Option<&'static Tz>, /// Type of the original request that initiated this subscription pub request_type: Option, /// Whether this is a smart depth subscription pub is_smart_depth: bool, } +impl DecoderContext { + /// Create a new context with server version and optional timezone + pub fn new(server_version: i32, time_zone: Option<&'static Tz>) -> Self { + Self { + server_version, + time_zone, + request_type: None, + is_smart_depth: false, + } + } + + /// Set the request type + #[allow(dead_code)] + pub fn with_request_type(mut self, request_type: OutgoingMessages) -> Self { + self.request_type = Some(request_type); + self + } + + /// Set the smart depth flag + pub fn with_smart_depth(mut self, is_smart_depth: bool) -> Self { + self.is_smart_depth = is_smart_depth; + self + } +} + /// Common trait for decoding streaming data responses /// /// This trait is shared between sync and async implementations to avoid code duplication. -/// The key change from the original design is that `decode` takes `server_version` directly -/// instead of the entire `Client`, making it possible to share implementations. +/// Decoders receive a `DecoderContext` containing server version, timezone, and other +/// context needed to properly decode messages. pub(crate) trait StreamDecoder { /// Message types this stream can handle #[allow(dead_code)] const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[]; /// Decode a response message into the stream's data type - fn decode(server_version: i32, message: &mut ResponseMessage) -> Result; + fn decode(context: &DecoderContext, message: &mut ResponseMessage) -> Result; /// Generate a cancellation message for this stream - fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, _request_id: Option, _context: Option<&DecoderContext>) -> Result { Err(Error::NotImplemented) } diff --git a/src/subscriptions/mod.rs b/src/subscriptions/mod.rs index 36952b01..ed6ffdb4 100644 --- a/src/subscriptions/mod.rs +++ b/src/subscriptions/mod.rs @@ -1,7 +1,7 @@ //! Subscription types for sync/async streaming data mod common; -pub(crate) use common::{ResponseContext, StreamDecoder}; +pub(crate) use common::{DecoderContext, StreamDecoder}; #[cfg(feature = "sync")] pub mod sync; diff --git a/src/subscriptions/sync.rs b/src/subscriptions/sync.rs index dddfa950..3476de31 100644 --- a/src/subscriptions/sync.rs +++ b/src/subscriptions/sync.rs @@ -7,8 +7,8 @@ use std::time::Duration; use log::{debug, error, warn}; -use super::common::{check_retry, process_decode_result, should_retry_error, should_store_error, ProcessingResult, RetryDecision}; -use super::{ResponseContext, StreamDecoder}; +use super::common::{check_retry, process_decode_result, should_retry_error, should_store_error, DecoderContext, ProcessingResult, RetryDecision}; +use super::StreamDecoder; use crate::errors::Error; use crate::messages::{OutgoingMessages, ResponseMessage}; use crate::transport::{InternalSubscription, MessageBus}; @@ -20,7 +20,7 @@ use crate::transport::{InternalSubscription, MessageBus}; /// Alternatively, you may poll subscriptions in a blocking or non-blocking manner using the [next](Subscription::next), [try_next](Subscription::try_next) or [next_timeout](Subscription::next_timeout) methods. #[allow(private_bounds)] pub struct Subscription> { - server_version: i32, + context: DecoderContext, message_bus: Arc, request_id: Option, order_id: Option, @@ -29,31 +29,24 @@ pub struct Subscription> { cancelled: AtomicBool, snapshot_ended: AtomicBool, subscription: InternalSubscription, - response_context: Option, error: Mutex>, retry_count: AtomicUsize, } #[allow(private_bounds)] impl> Subscription { - pub(crate) fn new( - server_version: i32, - message_bus: Arc, - subscription: InternalSubscription, - context: Option, - ) -> Self { + pub(crate) fn new(message_bus: Arc, subscription: InternalSubscription, context: DecoderContext) -> Self { let request_id = subscription.request_id; let order_id = subscription.order_id; let message_type = subscription.message_type; Subscription { - server_version, + context, message_bus, request_id, order_id, message_type, subscription, - response_context: context, phantom: PhantomData, cancelled: AtomicBool::new(false), snapshot_ended: AtomicBool::new(false), @@ -77,21 +70,21 @@ impl> Subscription { self.cancelled.store(true, Ordering::Relaxed); if let Some(request_id) = self.request_id { - if let Ok(message) = T::cancel_message(self.server_version, self.request_id, self.response_context.as_ref()) { + if let Ok(message) = T::cancel_message(self.context.server_version, self.request_id, Some(&self.context)) { if let Err(e) = self.message_bus.cancel_subscription(request_id, &message) { warn!("error cancelling subscription: {e}") } self.subscription.cancel(); } } else if let Some(order_id) = self.order_id { - if let Ok(message) = T::cancel_message(self.server_version, self.request_id, self.response_context.as_ref()) { + if let Ok(message) = T::cancel_message(self.context.server_version, self.request_id, Some(&self.context)) { if let Err(e) = self.message_bus.cancel_order_subscription(order_id, &message) { warn!("error cancelling order subscription: {e}") } self.subscription.cancel(); } } else if let Some(message_type) = self.message_type { - if let Ok(message) = T::cancel_message(self.server_version, self.request_id, self.response_context.as_ref()) { + if let Ok(message) = T::cancel_message(self.context.server_version, self.request_id, Some(&self.context)) { if let Err(e) = self.message_bus.cancel_shared_subscription(message_type, &message) { warn!("error cancelling shared subscription: {e}") } @@ -195,7 +188,7 @@ impl> Subscription { } fn process_message(&self, mut message: ResponseMessage) -> Option { - match process_decode_result(T::decode(self.server_version, &mut message)) { + match process_decode_result(T::decode(&self.context, &mut message)) { ProcessingResult::Success(val) => { // Check if this decoded value represents the end of a snapshot subscription if val.is_snapshot_end() { diff --git a/src/wsh/async.rs b/src/wsh/async.rs index e500f389..3365d3d2 100644 --- a/src/wsh/async.rs +++ b/src/wsh/async.rs @@ -204,12 +204,12 @@ mod tests { #[tokio::test] async fn test_wsh_metadata_decode_table() { use crate::messages::ResponseMessage; - use crate::subscriptions::StreamDecoder; + use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::wsh::common::test_tables::WSH_METADATA_DECODE_TESTS; for test_case in WSH_METADATA_DECODE_TESTS { let mut message = ResponseMessage::from(test_case.message); - let result = WshMetadata::decode(0, &mut message); + let result = WshMetadata::decode(&DecoderContext::default(), &mut message); if test_case.should_error { assert!(result.is_err(), "Test '{}' should have failed", test_case.name); @@ -232,12 +232,12 @@ mod tests { #[tokio::test] async fn test_wsh_event_data_decode_table() { use crate::messages::ResponseMessage; - use crate::subscriptions::StreamDecoder; + use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::wsh::common::test_tables::WSH_EVENT_DATA_DECODE_TESTS; for test_case in WSH_EVENT_DATA_DECODE_TESTS { let mut message = ResponseMessage::from(test_case.message); - let result = WshEventData::decode(0, &mut message); + let result = WshEventData::decode(&DecoderContext::default(), &mut message); if test_case.should_error { assert!(result.is_err(), "Test '{}' should have failed", test_case.name); diff --git a/src/wsh/common/stream_decoders.rs b/src/wsh/common/stream_decoders.rs index 1e9bd3a9..578eb747 100644 --- a/src/wsh/common/stream_decoders.rs +++ b/src/wsh/common/stream_decoders.rs @@ -5,7 +5,7 @@ use crate::common::error_helpers; use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage}; -use crate::subscriptions::{ResponseContext, StreamDecoder}; +use crate::subscriptions::{DecoderContext, StreamDecoder}; use crate::wsh::*; use crate::Error; @@ -14,14 +14,14 @@ use super::decoders; impl StreamDecoder for WshMetadata { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::WshMetaData]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { match message.message_type() { IncomingMessages::WshMetaData => decoders::decode_wsh_metadata(message.clone()), _ => Err(Error::UnexpectedResponse(message.clone())), } } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id_for(request_id, "encode cancel wsh metadata message")?; super::encoders::encode_cancel_wsh_metadata(request_id) } @@ -30,11 +30,11 @@ impl StreamDecoder for WshMetadata { impl StreamDecoder for WshEventData { const RESPONSE_MESSAGE_IDS: &'static [IncomingMessages] = &[IncomingMessages::WshEventData]; - fn decode(_server_version: i32, message: &mut ResponseMessage) -> Result { + fn decode(_context: &DecoderContext, message: &mut ResponseMessage) -> Result { decoders::decode_event_data_message(message.clone()) } - fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&ResponseContext>) -> Result { + fn cancel_message(_server_version: i32, request_id: Option, _context: Option<&DecoderContext>) -> Result { let request_id = error_helpers::require_request_id_for(request_id, "encode cancel wsh event data message")?; super::encoders::encode_cancel_wsh_event_data(request_id) } diff --git a/src/wsh/sync.rs b/src/wsh/sync.rs index 5572af1c..70dc8cad 100644 --- a/src/wsh/sync.rs +++ b/src/wsh/sync.rs @@ -384,11 +384,12 @@ mod tests { #[test] fn test_wsh_metadata_decode_table() { + use crate::subscriptions::DecoderContext; use crate::wsh::common::test_tables::WSH_METADATA_DECODE_TESTS; for test_case in WSH_METADATA_DECODE_TESTS { let mut message = ResponseMessage::from(test_case.message); - let result = WshMetadata::decode(0, &mut message); + let result = WshMetadata::decode(&DecoderContext::default(), &mut message); if test_case.should_error { assert!(result.is_err(), "Test '{}' should have failed", test_case.name); @@ -410,11 +411,12 @@ mod tests { #[test] fn test_wsh_event_data_decode_table() { + use crate::subscriptions::DecoderContext; use crate::wsh::common::test_tables::WSH_EVENT_DATA_DECODE_TESTS; for test_case in WSH_EVENT_DATA_DECODE_TESTS { let mut message = ResponseMessage::from(test_case.message); - let result = WshEventData::decode(0, &mut message); + let result = WshEventData::decode(&DecoderContext::default(), &mut message); if test_case.should_error { assert!(result.is_err(), "Test '{}' should have failed", test_case.name); From 86cb90fd19a1030ad2eb746c59e62429af71e711 Mon Sep 17 00:00:00 2001 From: Wil Boayue Date: Sat, 31 Jan 2026 00:47:38 -0800 Subject: [PATCH 17/18] Update README version to 2.7 --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ae928ac4..166504c8 100644 --- a/README.md +++ b/README.md @@ -19,13 +19,13 @@ rust-ibapi ships both asynchronous (Tokio) and blocking (threaded) clients. The ```toml # Async only (default features) -ibapi = "2.1" +ibapi = "2.7" # Blocking only -ibapi = { version = "2.1", default-features = false, features = ["sync"] } +ibapi = { version = "2.7", default-features = false, features = ["sync"] } # Async + blocking together -ibapi = { version = "2.1", default-features = false, features = ["sync", "async"] } +ibapi = { version = "2.7", default-features = false, features = ["sync", "async"] } ``` ```bash From d5b6e300619f525bcc8084dd7f955469a14c8421 Mon Sep 17 00:00:00 2001 From: Abel E Date: Mon, 2 Feb 2026 01:49:58 +0400 Subject: [PATCH 18/18] Enable timezone-aware execution timestamps (#387) * feat(orders): add `last_n_days` and `specific_dates` to `ExecutionFilter` * feat(connection): bump server `max_version` to `200` (PARAMETRIZED_DAYS_OF_EXECUTIONS) * Add tests for encode_executions version-gated date fields Tests verify: - Version < 200: date fields not encoded - Version >= 200: last_n_days and specific_dates encoded * Update mock version responses to 200, improve doc comment - Mock server responses now return v200 to match advertised range - Add format spec and example to specific_dates doc * Fix test helper to include issuer_id field for v200 Server version 200 >= BOND_ISSUERID (176), so encoder adds issuer_id. Test helper was missing this field, causing packet mismatch with mock. --------- Co-authored-by: Wil Boayue --- src/client/async.rs | 1 + src/client/sync.rs | 1 + src/connection/common.rs | 2 +- src/orders/common/encoders.rs | 75 +++++++++++++++++++++++++++++++++++ src/orders/mod.rs | 4 ++ src/orders/sync.rs | 1 + src/transport/sync.rs | 33 ++++++++------- 7 files changed, 101 insertions(+), 16 deletions(-) diff --git a/src/client/async.rs b/src/client/async.rs index 7f9c27d7..a498cb4a 100644 --- a/src/client/async.rs +++ b/src/client/async.rs @@ -3324,6 +3324,7 @@ mod tests { security_type: "".to_string(), // Empty means all types exchange: "".to_string(), // Empty means all exchanges side: "".to_string(), // Empty means all sides + ..Default::default() }; // Request executions diff --git a/src/client/sync.rs b/src/client/sync.rs index 00071794..b9cb7f44 100644 --- a/src/client/sync.rs +++ b/src/client/sync.rs @@ -3457,6 +3457,7 @@ mod tests { security_type: "".to_string(), // Empty means all types exchange: "".to_string(), // Empty means all exchanges side: "".to_string(), // Empty means all sides + ..Default::default() }; // Request executions diff --git a/src/connection/common.rs b/src/connection/common.rs index 3cc38385..77bc02cc 100644 --- a/src/connection/common.rs +++ b/src/connection/common.rs @@ -65,7 +65,7 @@ impl Default for ConnectionHandler { fn default() -> Self { Self { min_version: 100, - max_version: server_versions::WSH_EVENT_DATA_FILTERS_DATE, + max_version: server_versions::PARAMETRIZED_DAYS_OF_EXECUTIONS, } } } diff --git a/src/orders/common/encoders.rs b/src/orders/common/encoders.rs index 3e712a8d..aec918e5 100644 --- a/src/orders/common/encoders.rs +++ b/src/orders/common/encoders.rs @@ -488,6 +488,14 @@ pub(crate) fn encode_executions(server_version: i32, request_id: i32, filter: &E message.push_field(&filter.exchange); message.push_field(&filter.side); + if server_version >= server_versions::PARAMETRIZED_DAYS_OF_EXECUTIONS { + message.push_field(&filter.last_n_days); + message.push_field(&(filter.specific_dates.len() as i32)); + for date in &filter.specific_dates { + message.push_field(date); + } + } + Ok(message) } @@ -817,4 +825,71 @@ pub(crate) mod tests { assert_eq!(field_vec[4], "0"); // is_more (false) assert_eq!(field_vec[5], "5.5"); // percent } + + #[test] + fn test_encode_executions_without_date_filter() { + let filter = ExecutionFilter { + client_id: Some(1), + account_code: "DU123456".to_string(), + time: "20260101 09:30:00".to_string(), + symbol: "AAPL".to_string(), + security_type: "STK".to_string(), + exchange: "SMART".to_string(), + side: "BUY".to_string(), + ..Default::default() + }; + + // Version below PARAMETRIZED_DAYS_OF_EXECUTIONS should not include date fields + let result = encode_executions(server_versions::WSH_EVENT_DATA_FILTERS_DATE, 9000, &filter).unwrap(); + let fields = result.encode(); + let field_vec: Vec<&str> = fields.split('\0').collect(); + + assert_eq!(field_vec[0], "7"); // RequestExecutions + assert_eq!(field_vec[1], "3"); // VERSION + assert_eq!(field_vec[2], "9000"); // request_id + assert_eq!(field_vec[3], "1"); // client_id + assert_eq!(field_vec[4], "DU123456"); // account_code + assert_eq!(field_vec[5], "20260101 09:30:00"); // time + assert_eq!(field_vec[6], "AAPL"); // symbol + assert_eq!(field_vec[7], "STK"); // security_type + assert_eq!(field_vec[8], "SMART"); // exchange + assert_eq!(field_vec[9], "BUY"); // side + assert_eq!(field_vec.len(), 11); // 10 fields + trailing empty + } + + #[test] + fn test_encode_executions_with_date_filter() { + let filter = ExecutionFilter { + client_id: Some(1), + account_code: "DU123456".to_string(), + time: "".to_string(), + symbol: "".to_string(), + security_type: "".to_string(), + exchange: "".to_string(), + side: "".to_string(), + last_n_days: 7, + specific_dates: vec!["20260125".to_string(), "20260126".to_string()], + }; + + // Version at PARAMETRIZED_DAYS_OF_EXECUTIONS should include date fields + let result = encode_executions(server_versions::PARAMETRIZED_DAYS_OF_EXECUTIONS, 9000, &filter).unwrap(); + let fields = result.encode(); + let field_vec: Vec<&str> = fields.split('\0').collect(); + + assert_eq!(field_vec[0], "7"); // RequestExecutions + assert_eq!(field_vec[1], "3"); // VERSION + assert_eq!(field_vec[2], "9000"); // request_id + assert_eq!(field_vec[3], "1"); // client_id + assert_eq!(field_vec[4], "DU123456"); // account_code + assert_eq!(field_vec[5], ""); // time + assert_eq!(field_vec[6], ""); // symbol + assert_eq!(field_vec[7], ""); // security_type + assert_eq!(field_vec[8], ""); // exchange + assert_eq!(field_vec[9], ""); // side + assert_eq!(field_vec[10], "7"); // last_n_days + assert_eq!(field_vec[11], "2"); // specific_dates count + assert_eq!(field_vec[12], "20260125"); // specific_dates[0] + assert_eq!(field_vec[13], "20260126"); // specific_dates[1] + assert_eq!(field_vec.len(), 15); // 14 fields + trailing empty + } } diff --git a/src/orders/mod.rs b/src/orders/mod.rs index 4234559d..0ffa43cd 100644 --- a/src/orders/mod.rs +++ b/src/orders/mod.rs @@ -1446,6 +1446,10 @@ pub struct ExecutionFilter { pub exchange: String, /// The Contract's side (BUY or SELL) pub side: String, + /// Filter executions from the last N days (0 = no filter). + pub last_n_days: i32, + /// Filter executions for specific dates (format: yyyymmdd, e.g., "20260130"). + pub specific_dates: Vec, } /// Enumerates possible results from querying an [Execution]. diff --git a/src/orders/sync.rs b/src/orders/sync.rs index b42f0aa4..e970de70 100644 --- a/src/orders/sync.rs +++ b/src/orders/sync.rs @@ -954,6 +954,7 @@ mod tests { security_type: "STK".to_owned(), exchange: "ISLAND".to_owned(), side: "BUY".to_owned(), + ..Default::default() }; let results = client.executions(filter); diff --git a/src/transport/sync.rs b/src/transport/sync.rs index d7185fdd..d0f26ec8 100644 --- a/src/transport/sync.rs +++ b/src/transport/sync.rs @@ -822,6 +822,9 @@ mod tests { packet.push_field(&contract.security_id_type); packet.push_field(&contract.security_id); + // Server version 200 includes issuer_id (>= 176) + packet.push_field(&contract.issuer_id); + Ok(packet) } @@ -1047,7 +1050,7 @@ mod tests { let request = encode_place_order(176, 5, contract, &order)?; let events = vec![ - Exchange::simple("v100..173", &["173|20250415 19:38:30 British Summer Time|"]), + Exchange::simple("v100..200", &["200|20250415 19:38:30 British Summer Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|5|"]), Exchange::request(request.clone(), &[ @@ -1080,7 +1083,7 @@ mod tests { #[test] fn test_connection_establish_connection() -> Result<(), Error> { let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple( "71|2|28||", &[ @@ -1102,7 +1105,7 @@ mod tests { #[test] fn test_reconnect_failed() -> Result<(), Error> { let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|", "\0"]), // RESTART ]; let socket = MockSocket::new(events, MAX_RECONNECT_ATTEMPTS as usize + 1); @@ -1122,9 +1125,9 @@ mod tests { #[test] fn test_reconnect_success() -> Result<(), Error> { let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|", "\0"]), // RESTART - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), ]; let socket = MockSocket::new(events, MAX_RECONNECT_ATTEMPTS as usize - 1); @@ -1141,10 +1144,10 @@ mod tests { #[test] fn test_client_reconnect() -> Result<(), Error> { let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), Exchange::simple("17|1|", &["\0"]), // ManagedAccounts RESTART - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), Exchange::simple("17|1|", &["15|1|DU1234567|"]), // ManagedAccounts ]; @@ -1170,9 +1173,9 @@ mod tests { let expected_response = &format!("10|9000|{AAPL_CONTRACT_RESPONSE}"); let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|", "\0"]), // RESTART - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), Exchange::request(packet.clone(), &[expected_response, "52|1|9001|"]), ]; @@ -1204,10 +1207,10 @@ mod tests { let packet = encode_request_contract_data(173, 9000, &Contract::stock("AAPL").build())?; let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), Exchange::request(packet.clone(), &["\0"]), // RESTART - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), ]; @@ -1236,9 +1239,9 @@ mod tests { let packet = encode_request_contract_data(173, 9000, &Contract::stock("AAPL").build())?; let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|", "\0"]), // RESTART - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::request(packet.clone(), &[]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), ]; @@ -1268,10 +1271,10 @@ mod tests { let packet = encode_request_contract_data(173, 9000, contract)?; let events = vec![ - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), Exchange::request(packet.clone(), &["\0"]), - Exchange::simple("v100..173", &["173|20250323 22:21:01 Greenwich Mean Time|"]), + Exchange::simple("v100..200", &["200|20250323 22:21:01 Greenwich Mean Time|"]), Exchange::simple("71|2|28||", &["15|1|DU1234567|", "9|1|1|"]), ];