From 5ccf15427801ed6761fd673a18acb03e218bdf3e Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Sun, 9 Feb 2025 16:10:06 +0100 Subject: [PATCH 1/4] Feat: Bolt 5.8 - home db cache & optimistic routing --- CHANGELOG.md | 2 + neo4j/src/driver.rs | 44 ++- neo4j/src/driver/config.rs | 2 +- neo4j/src/driver/home_db_cache.rs | 249 ++++++++++++++ neo4j/src/driver/io.rs | 4 +- neo4j/src/driver/io/bolt.rs | 100 ++++-- neo4j/src/driver/io/bolt/bolt4x4/protocol.rs | 7 +- neo4j/src/driver/io/bolt/bolt5x0/protocol.rs | 22 +- neo4j/src/driver/io/bolt/bolt5x1/protocol.rs | 6 +- neo4j/src/driver/io/bolt/bolt5x2/protocol.rs | 16 +- neo4j/src/driver/io/bolt/bolt5x3/protocol.rs | 6 +- neo4j/src/driver/io/bolt/bolt5x4/protocol.rs | 7 +- neo4j/src/driver/io/bolt/bolt5x7/protocol.rs | 6 +- neo4j/src/driver/io/bolt/bolt5x8.rs | 19 + neo4j/src/driver/io/bolt/bolt5x8/protocol.rs | 325 ++++++++++++++++++ .../src/driver/io/bolt/bolt5x8/translator.rs | 17 + neo4j/src/driver/io/bolt/bolt_common.rs | 37 +- neo4j/src/driver/io/bolt/bolt_state.rs | 2 +- neo4j/src/driver/io/bolt/handshake.rs | 8 +- .../src/driver/io/bolt/message_parameters.rs | 12 +- neo4j/src/driver/io/deadline.rs | 14 +- neo4j/src/driver/io/pool.rs | 179 +++++++--- neo4j/src/driver/io/pool/routing.rs | 4 +- neo4j/src/driver/io/pool/single_pool.rs | 38 +- neo4j/src/driver/io/pool/ssr_tracker.rs | 86 +++++ neo4j/src/driver/record_stream.rs | 8 +- neo4j/src/driver/session.rs | 284 ++++++++++++--- neo4j/src/driver/transaction.rs | 5 +- neo4j/src/lib.rs | 8 +- neo4j/src/test_data/public-api.txt | 2 +- neo4j/src/value/value_send.rs | 7 +- .../src/testkit_backend/requests.rs | 6 + .../src/testkit_backend/responses.rs | 20 +- 33 files changed, 1316 insertions(+), 236 deletions(-) create mode 100644 neo4j/src/driver/home_db_cache.rs create mode 100644 neo4j/src/driver/io/bolt/bolt5x8.rs create mode 100644 neo4j/src/driver/io/bolt/bolt5x8/protocol.rs create mode 100644 neo4j/src/driver/io/bolt/bolt5x8/translator.rs create mode 100644 neo4j/src/driver/io/pool/ssr_tracker.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index dc679b0..78e0dba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ User-code should not need to create arbitrary `ServerError`s. In return, `ServerError` now implements `Clone`. - Add support for bolt handshake manifest v1. + - Add support for Bolt 5.8 (home database resolution cache) + - Includes an optimization where the driver uses a home/default database cache to perform optimistic routing under certain circumstances, saving a full round trip. See the [PR description](https://github.com/robsdedude/neo4j-rust-driver/pull/28) for more details. **🔧 Fixes** - Rework `neo4j::value::graph::Path` diff --git a/neo4j/src/driver.rs b/neo4j/src/driver.rs index e3cd7ff..294efdd 100644 --- a/neo4j/src/driver.rs +++ b/neo4j/src/driver.rs @@ -14,6 +14,7 @@ pub(crate) mod config; pub(crate) mod eager_result; +mod home_db_cache; pub(crate) mod io; pub(crate) mod record; pub mod record_stream; @@ -40,10 +41,11 @@ pub use config::{ InvalidRoutingContextError, KeepAliveConfig, TlsConfigError, }; pub use eager_result::{EagerResult, ScalarError}; +use home_db_cache::HomeDbCache; use io::bolt::message_parameters::TelemetryAPI; #[cfg(feature = "_internal_testkit_backend")] pub use io::ConnectionPoolMetrics; -use io::{AcquireConfig, Pool, PoolConfig, PooledBolt, SessionAuth, UpdateRtArgs}; +use io::{AcquireConfig, Pool, PoolConfig, PooledBolt, SessionAuth, UpdateRtArgs, UpdateRtDb}; use notification::NotificationFilter; pub use record::Record; use record_stream::RecordStream; @@ -76,8 +78,9 @@ pub mod notification { /// * [`Driver::session()`] for several mechanisms offering more advance patterns. #[derive(Debug)] pub struct Driver { - pub(crate) config: ReducedDriverConfig, - pub(crate) pool: Pool, + config: ReducedDriverConfig, + pool: Pool, + home_db_cache: Arc, capability_check_config: SessionConfig, execute_query_bookmark_manager: Arc, } @@ -120,6 +123,7 @@ impl Driver { idle_time_before_connection_test: config.idle_time_before_connection_test, }, pool: Pool::new(Arc::new(connection_config.address), pool_config), + home_db_cache: Default::default(), capability_check_config: SessionConfig::default() .with_database(Arc::new(String::from("system"))), execute_query_bookmark_manager: Arc::new(bookmark_managers::simple(None)), @@ -169,7 +173,12 @@ impl Driver { idle_time_before_connection_test: self.config.idle_time_before_connection_test, eager_begin: true, }; - Session::new(config, &self.pool, &self.config) + Session::new( + config, + &self.pool, + Arc::clone(&self.home_db_cache), + &self.config, + ) } fn execute_query_session( @@ -197,7 +206,12 @@ impl Driver { idle_time_before_connection_test: self.config.idle_time_before_connection_test, eager_begin: false, }; - Session::new(config, &self.pool, &self.config) + Session::new( + config, + &self.pool, + Arc::clone(&self.home_db_cache), + &self.config, + ) } /// Execute a single query inside a transaction. @@ -329,7 +343,13 @@ impl Driver { idle_time_before_connection_test: Some(Duration::ZERO), eager_begin: true, }; - Session::new(config, &self.pool, &self.config) + let mut session = Session::new( + config, + &self.pool, + Arc::clone(&self.home_db_cache), + &self.config, + ); + session .acquire_connection(RoutingControl::Read) .and_then(|mut con| { con.write_all(None)?; @@ -380,11 +400,21 @@ impl Driver { self.pool.acquire(AcquireConfig { mode: RoutingControl::Read, update_rt_args: UpdateRtArgs { - db: self.capability_check_config.database.as_ref(), + db: self + .capability_check_config + .database + .as_ref() + .map(|db| UpdateRtDb { + db: Arc::clone(db), + guess: false, + }) + .as_ref(), bookmarks: None, imp_user: None, + deadline: self.pool.config.connection_acquisition_deadline(), session_auth: SessionAuth::None, idle_time_before_connection_test: None, + db_resolution_cb: None, }, }) } diff --git a/neo4j/src/driver/config.rs b/neo4j/src/driver/config.rs index 28da3dc..24f1e1e 100644 --- a/neo4j/src/driver/config.rs +++ b/neo4j/src/driver/config.rs @@ -795,7 +795,7 @@ impl ConnectionConfig { } } Some(query) => { - if query == "" { + if query.is_empty() { Some(HashMap::new()) } else { if !routing { diff --git a/neo4j/src/driver/home_db_cache.rs b/neo4j/src/driver/home_db_cache.rs new file mode 100644 index 0000000..a124180 --- /dev/null +++ b/neo4j/src/driver/home_db_cache.rs @@ -0,0 +1,249 @@ +// Copyright Rouven Bauer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use log::debug; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::mem; +use std::sync::Arc; +use std::time::Instant; + +use super::auth::AuthToken; +use crate::value::spatial; +use crate::{value, ValueSend}; + +#[derive(Debug)] +pub(super) struct HomeDbCache { + cache: Mutex>, + config: HomeDbCacheConfig, +} + +#[derive(Debug, Copy, Clone)] +struct HomeDbCacheConfig { + max_size: usize, + prune_size: usize, +} + +impl Default for HomeDbCache { + fn default() -> Self { + Self::new(1000) + } +} + +impl HomeDbCache { + pub(super) fn new(max_size: usize) -> Self { + let max_size_f64 = max_size as f64; + let prune_size = usize::min(max_size, (max_size_f64 * 0.01).log(max_size_f64) as usize); + HomeDbCache { + cache: Mutex::new(HashMap::with_capacity(max_size)), + config: HomeDbCacheConfig { + max_size, + prune_size, + }, + } + } + + pub(super) fn get(&self, key: &HomeDbCacheKey) -> Option> { + let mut lock = self.cache.lock(); + let cache: &mut HashMap = &mut lock; + let res = cache.get_mut(key).map(|entry| { + entry.last_used = Instant::now(); + Arc::clone(&entry.database) + }); + debug!( + "Getting home database cache for key: {} -> {:?}", + key.log_str(), + res.as_deref(), + ); + res + } + + pub(super) fn update(&self, key: HomeDbCacheKey, database: Arc) { + let mut lock = self.cache.lock(); + debug!( + "Updating home database cache for key: {} -> {:?}", + key.log_str(), + database.as_str(), + ); + let cache: &mut HashMap = &mut lock; + let previous_val = cache.insert( + key, + HomeDbCacheEntry { + database, + last_used: Instant::now(), + }, + ); + if previous_val.is_none() { + // cache grew, prune if necessary + Self::prune(cache, self.config); + } + } + + fn prune(cache: &mut HashMap, config: HomeDbCacheConfig) { + if cache.len() <= config.max_size { + return; + } + debug!( + "Pruning home database cache to size: {}", + config.max_size - config.prune_size + ); + let new_cache = mem::take(cache); + *cache = new_cache + .into_iter() + .sorted_by(|(_, v1), (_, v2)| v2.last_used.cmp(&v1.last_used)) + .take(config.max_size - config.prune_size) + .collect(); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(super) enum HomeDbCacheKey { + DriverUser, + FixedUser(Arc), + SessionAuth(SessionAuthKey), +} + +impl HomeDbCacheKey { + fn log_str(&self) -> String { + match self { + HomeDbCacheKey::DriverUser | HomeDbCacheKey::FixedUser(_) => format!("{:?}", self), + HomeDbCacheKey::SessionAuth(SessionAuthKey(auth)) => { + let mut auth: AuthToken = (**auth).clone(); + auth.data + .get_mut("credentials") + .map(|c| *c = value!("**********")); + format!("SessionAuth({:?})", auth.data) + } + } + } +} + +#[derive(Debug, Clone)] +pub(super) struct SessionAuthKey(Arc); + +impl PartialEq for SessionAuthKey { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + || self.0.data.len() == other.0.data.len() + && self + .0 + .data + .iter() + .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) + .zip(other.0.data.iter().sorted_by(|(k1, _), (k2, _)| k1.cmp(k2))) + .all(|((k1, v1), (k2, v2))| k1 == k2 && v1.eq_data(v2)) + } +} + +impl Eq for SessionAuthKey {} + +impl Hash for SessionAuthKey { + fn hash(&self, state: &mut H) { + self.0 + .data + .iter() + .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) + .for_each(|(k, v)| { + k.hash(state); + Self::hash(v, state); + }); + } +} + +impl SessionAuthKey { + fn hash(v: &ValueSend, state: &mut impl Hasher) { + match v { + ValueSend::Null => state.write_usize(0), + ValueSend::Boolean(v) => v.hash(state), + ValueSend::Integer(v) => v.hash(state), + ValueSend::Float(v) => v.to_bits().hash(state), + ValueSend::Bytes(v) => v.hash(state), + ValueSend::String(v) => v.hash(state), + ValueSend::List(v) => v.iter().for_each(|v| Self::hash(v, state)), + ValueSend::Map(v) => { + v.iter() + .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) + .for_each(|(k, v)| { + k.hash(state); + Self::hash(v, state); + }); + } + ValueSend::Cartesian2D(spatial::Cartesian2D { srid, coordinates }) => { + srid.hash(state); + coordinates + .iter() + .map(|v| v.to_bits()) + .for_each(|v| v.hash(state)); + } + ValueSend::Cartesian3D(spatial::Cartesian3D { srid, coordinates }) => { + srid.hash(state); + coordinates + .iter() + .map(|v| v.to_bits()) + .for_each(|v| v.hash(state)); + } + ValueSend::WGS84_2D(spatial::WGS84_2D { srid, coordinates }) => { + srid.hash(state); + coordinates + .iter() + .map(|v| v.to_bits()) + .for_each(|v| v.hash(state)); + } + ValueSend::WGS84_3D(spatial::WGS84_3D { srid, coordinates }) => { + srid.hash(state); + coordinates + .iter() + .map(|v| v.to_bits()) + .for_each(|v| v.hash(state)); + } + ValueSend::Duration(v) => v.hash(state), + ValueSend::LocalTime(v) => v.hash(state), + ValueSend::Time(v) => v.hash(state), + ValueSend::Date(v) => v.hash(state), + ValueSend::LocalDateTime(v) => v.hash(state), + ValueSend::DateTime(v) => v.hash(state), + ValueSend::DateTimeFixed(v) => v.hash(state), + } + } +} + +impl HomeDbCacheKey { + pub(super) fn new( + imp_user: Option<&Arc>, + session_auth: Option<&Arc>, + ) -> Self { + if let Some(user) = imp_user { + HomeDbCacheKey::FixedUser(Arc::clone(user)) + } else if let Some(auth) = session_auth { + if let Some(ValueSend::String(scheme)) = auth.data.get("scheme") { + if scheme == "basic" { + if let Some(ValueSend::String(user)) = auth.data.get("principal") { + return HomeDbCacheKey::FixedUser(Arc::new(user.clone())); + } + } + } + HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(auth))) + } else { + HomeDbCacheKey::DriverUser + } + } +} + +#[derive(Debug, Clone)] +struct HomeDbCacheEntry { + database: Arc, + last_used: Instant, +} diff --git a/neo4j/src/driver/io.rs b/neo4j/src/driver/io.rs index 1655f57..f287c29 100644 --- a/neo4j/src/driver/io.rs +++ b/neo4j/src/driver/io.rs @@ -19,4 +19,6 @@ mod varint; #[cfg(feature = "_internal_testkit_backend")] pub use pool::ConnectionPoolMetrics; -pub(crate) use pool::{AcquireConfig, Pool, PoolConfig, PooledBolt, SessionAuth, UpdateRtArgs}; +pub(crate) use pool::{ + AcquireConfig, Pool, PoolConfig, PooledBolt, SessionAuth, UpdateRtArgs, UpdateRtDb, +}; diff --git a/neo4j/src/driver/io/bolt.rs b/neo4j/src/driver/io/bolt.rs index b841c1b..086576c 100644 --- a/neo4j/src/driver/io/bolt.rs +++ b/neo4j/src/driver/io/bolt.rs @@ -22,6 +22,7 @@ mod bolt5x3; mod bolt5x4; mod bolt5x6; mod bolt5x7; +mod bolt5x8; mod bolt_state; mod chunk; mod handshake; @@ -45,7 +46,6 @@ use std::time::Duration; use atomic_refcell::AtomicRefCell; use enum_dispatch::enum_dispatch; -use log::debug; use usize_cast::FromUsize; use super::deadline::DeadlineIO; @@ -63,6 +63,7 @@ use bolt5x3::{Bolt5x3, Bolt5x3StructTranslator}; use bolt5x4::{Bolt5x4, Bolt5x4StructTranslator}; use bolt5x6::{Bolt5x6, Bolt5x6StructTranslator}; use bolt5x7::{Bolt5x7, Bolt5x7StructTranslator}; +use bolt5x8::{Bolt5x8, Bolt5x8StructTranslator}; use bolt_state::{BoltState, BoltStateTracker}; use chunk::{Chunker, Dechunker}; pub(crate) use handshake::{open, TcpConnector}; @@ -80,70 +81,99 @@ pub(crate) use socket::{BufTcpStream, Socket}; macro_rules! debug_buf_start { ($name:ident) => { - let mut $name: Option = match log_enabled!(Level::Debug) { - true => Some(String::new()), - false => None, - }; + let mut $name = None; + { + #![allow(unused_imports)] + use log::{log_enabled, Level}; + + if log_enabled!(Level::Debug) { + $name = Some(String::new()); + } + } }; } pub(crate) use debug_buf_start; macro_rules! debug_buf { - ($name:ident, $($args:tt)+) => { + ($name:ident, $($args:tt)+) => {{ + #![allow(unused_imports)] + use log::{log_enabled, Level}; + if log_enabled!(Level::Debug) { $name.as_mut().unwrap().push_str(&format!($($args)*)) }; - } + }} } pub(crate) use debug_buf; macro_rules! bolt_debug_extra { ($meta:expr, $local_port:expr) => { 'a: { - let meta = $meta; - // ugly format because rust-fmt is broken - let Ok(meta) = meta else { - break 'a dbg_extra($local_port, Some("!!!!")); - }; - let Some(ValueReceive::String(id)) = meta.get("connection_id") else { - break 'a dbg_extra($local_port, None); - }; - dbg_extra($local_port, Some(id)) + { + #![allow(unused_imports)] + use crate::driver::io::bolt::dbg_extra; + + use crate::value::ValueReceive; + + let meta = $meta; + let Ok(meta) = meta else { + break 'a dbg_extra($local_port, Some("!!!!")); + }; + let Some(ValueReceive::String(id)) = meta.get("connection_id") else { + break 'a dbg_extra($local_port, None); + }; + dbg_extra($local_port, Some(id)) + } } }; } pub(crate) use bolt_debug_extra; macro_rules! debug_buf_end { - ($bolt:expr, $name:ident) => { + ($bolt:expr, $name:ident) => {{ + #![allow(unused_imports)] + use log::debug; + + use crate::driver::io::bolt::bolt_debug_extra; + debug!( "{}{}", bolt_debug_extra!($bolt.meta.try_borrow(), $bolt.local_port), $name.as_ref().map(|s| s.as_str()).unwrap_or("") ); - }; + }}; } pub(crate) use debug_buf_end; macro_rules! bolt_debug { - ($bolt:expr, $($args:tt)+) => { + ($bolt:expr, $($args:tt)+) => {{ + #![allow(unused_imports)] + use log::debug; + + use crate::driver::io::bolt::bolt_debug_extra; + debug!( "{}{}", bolt_debug_extra!($bolt.meta.try_borrow(), $bolt.local_port), format!($($args)*) ); - }; + }}; } pub(crate) use bolt_debug; macro_rules! socket_debug { - ($local_port:expr, $($args:tt)+) => { + ($local_port:expr, $($args:tt)+) => {{ + #![allow(unused_imports)] + use log::debug; + + use crate::driver::io::bolt::dbg_extra; + debug!( "{}{}", dbg_extra(Some($local_port), None), format!($($args)*) ); - }; + }}; } pub(crate) use socket_debug; @@ -179,6 +209,7 @@ impl Bolt { data: BoltData::new(version, stream, socket, local_port, address), // [bolt-version-bump] search tag when changing bolt version support protocol: match version { + (5, 8) => Bolt5x8::::default().into(), (5, 7) => Bolt5x7::::default().into(), (5, 6) => Bolt5x6::::default().into(), (5, 4) => Bolt5x4::::default().into(), @@ -392,6 +423,10 @@ impl Bolt { self.data.set_telemetry_enabled(enabled) } + pub(crate) fn ssr_enabled(&self) -> bool { + self.data.ssr_enabled() + } + #[inline(always)] pub(crate) fn debug_log(&self, msg: impl FnOnce() -> String) { bolt_debug!(self.data, "{}", msg()); @@ -488,14 +523,15 @@ trait BoltProtocol: Debug { #[enum_dispatch(BoltProtocol)] #[derive(Debug)] enum BoltProtocolVersion { - V4x4(Bolt4x4), - V5x0(Bolt5x0), - V5x1(Bolt5x1), - V5x2(Bolt5x2), - V5x3(Bolt5x3), - V5x4(Bolt5x4), - V5x6(Bolt5x6), + V5x8(Bolt5x8), V5x7(Bolt5x7), + V5x6(Bolt5x6), + V5x4(Bolt5x4), + V5x3(Bolt5x3), + V5x2(Bolt5x2), + V5x1(Bolt5x1), + V5x0(Bolt5x0), + V4x4(Bolt4x4), } #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] @@ -517,6 +553,7 @@ pub(crate) struct BoltData { meta: Arc>>, server_agent: Arc>>, telemetry_enabled: Arc>, + ssr_enabled: Arc>, address: Arc
, last_qid: Arc>>, auth: Option>, @@ -547,6 +584,7 @@ impl BoltData { meta: Default::default(), server_agent: Default::default(), telemetry_enabled: Default::default(), + ssr_enabled: Default::default(), address, last_qid: Default::default(), auth: None, @@ -714,6 +752,10 @@ impl BoltData { fn set_telemetry_enabled(&mut self, enabled: bool) { *self.telemetry_enabled.borrow_mut() = enabled; } + + fn ssr_enabled(&self) -> bool { + *(*self.ssr_enabled).borrow() + } } impl Debug for BoltData { diff --git a/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs b/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs index 0ea4cf8..f38ecec 100644 --- a/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt4x4/protocol.rs @@ -21,7 +21,7 @@ use std::ops::Deref; use std::sync::Arc; use atomic_refcell::AtomicRefCell; -use log::{debug, log_enabled, warn, Level}; +use log::warn; use usize_cast::FromUsize; use super::super::bolt5x0::Bolt5x0; @@ -36,9 +36,8 @@ use super::super::packstream::{ PackStreamSerializer, PackStreamSerializerDebugImpl, PackStreamSerializerImpl, }; use super::super::{ - bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, - BoltResponse, BoltStructTranslatorWithUtcPatch, OnServerErrorCb, ResponseCallbacks, - ResponseMessage, + debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, BoltResponse, + BoltStructTranslatorWithUtcPatch, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; use crate::error_::Result; use crate::value::ValueReceive; diff --git a/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs index bda7628..5293a1d 100644 --- a/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use std::time::Duration; use atomic_refcell::AtomicRefCell; -use log::{debug, log_enabled, warn, Level}; +use log::warn; use usize_cast::FromUsize; use super::super::bolt_common::{unsupported_protocol_feature_error, ServerAwareBoltVersion}; @@ -38,9 +38,9 @@ use super::super::packstream::{ PackStreamSerializerDebugImpl, PackStreamSerializerImpl, }; use super::super::{ - assert_response_field_count, bolt_debug, bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, - debug_buf_start, BoltData, BoltMeta, BoltProtocol, BoltResponse, BoltStructTranslator, - ConnectionState, OnServerErrorCb, ResponseCallbacks, ResponseMessage, + assert_response_field_count, bolt_debug, debug_buf, debug_buf_end, debug_buf_start, BoltData, + BoltMeta, BoltProtocol, BoltResponse, BoltStructTranslator, ConnectionState, OnServerErrorCb, + ResponseCallbacks, ResponseMessage, }; use crate::driver::config::auth::AuthToken; use crate::driver::config::notification::NotificationFilter; @@ -709,7 +709,12 @@ impl BoltProtocol for Bolt5x0 { Self::write_mode_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, mode)?; - Self::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; + Self::write_db_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + db.as_deref().map(String::as_str), + )?; Self::write_imp_user_entry( log_buf.as_mut(), @@ -833,7 +838,12 @@ impl BoltProtocol for Bolt5x0 { Self::write_mode_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, mode)?; - Self::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; + Self::write_db_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + db.as_deref().map(String::as_str), + )?; Self::write_imp_user_entry( log_buf.as_mut(), diff --git a/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs index b3e6fa4..31c3cd7 100644 --- a/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x1/protocol.rs @@ -17,7 +17,6 @@ use std::fmt::Debug; use std::io::{Read, Write}; use std::sync::Arc; -use log::{debug, log_enabled, Level}; use usize_cast::FromUsize; use super::super::bolt5x0::Bolt5x0; @@ -32,9 +31,8 @@ use super::super::packstream::{ PackStreamSerializer, PackStreamSerializerDebugImpl, PackStreamSerializerImpl, }; use super::super::{ - bolt_debug, bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, - BoltProtocol, BoltResponse, BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, - ResponseMessage, + bolt_debug, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, BoltResponse, + BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; use crate::driver::config::auth::AuthToken; use crate::error_::Result; diff --git a/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs index 778226b..3f18453 100644 --- a/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs @@ -17,7 +17,6 @@ use std::fmt::Debug; use std::io::{Read, Write}; use crate::driver::notification::NotificationFilter; -use log::{debug, log_enabled, Level}; use usize_cast::FromUsize; use super::super::bolt5x0::Bolt5x0; @@ -33,8 +32,8 @@ use super::super::packstream::{ PackStreamSerializer, PackStreamSerializerDebugImpl, PackStreamSerializerImpl, }; use super::super::{ - bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, - BoltResponse, BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, + debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, BoltResponse, + BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; use crate::error_::Result; use crate::value::ValueReceive; @@ -301,7 +300,12 @@ impl BoltProtocol for Bolt5x2 { mode, )?; - Bolt5x0::::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; + Bolt5x0::::write_db_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + db.as_deref().map(String::as_str), + )?; Bolt5x0::::write_imp_user_entry( log_buf.as_mut(), @@ -436,11 +440,11 @@ impl BoltProtocol for Bolt5x2 { if let Some(db) = db { debug_buf!(log_buf, "{}", { dbg_serializer.write_string("db").unwrap(); - dbg_serializer.write_string(db).unwrap(); + dbg_serializer.write_string(&db).unwrap(); dbg_serializer.flush() }); serializer.write_string("db")?; - serializer.write_string(db)?; + serializer.write_string(&db)?; } if let Some(imp_user) = imp_user { diff --git a/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs index 5b0c518..253f139 100644 --- a/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x3/protocol.rs @@ -16,8 +16,6 @@ use std::borrow::Borrow; use std::fmt::Debug; use std::io::{Read, Write}; -use log::{debug, log_enabled, Level}; - use super::super::bolt5x0::Bolt5x0; use super::super::bolt5x2::Bolt5x2; use super::super::bolt_common::{ @@ -34,8 +32,8 @@ use super::super::packstream::{ PackStreamSerializer, PackStreamSerializerDebugImpl, PackStreamSerializerImpl, }; use super::super::{ - bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, - BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, + debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, BoltStructTranslator, + OnServerErrorCb, ResponseCallbacks, }; use crate::error_::Result; use crate::value::ValueReceive; diff --git a/neo4j/src/driver/io/bolt/bolt5x4/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x4/protocol.rs index 9b54a20..313efb9 100644 --- a/neo4j/src/driver/io/bolt/bolt5x4/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x4/protocol.rs @@ -22,7 +22,7 @@ use std::net::TcpStream; use std::ops::Deref; use std::sync::Arc; -use log::{debug, log_enabled, warn, Level}; +use log::warn; use super::super::bolt5x0::Bolt5x0; use super::super::bolt5x2::Bolt5x2; @@ -38,9 +38,8 @@ use super::super::packstream::{ PackStreamSerializer, PackStreamSerializerDebugImpl, PackStreamSerializerImpl, }; use super::super::{ - bolt_debug_extra, dbg_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltMeta, - BoltProtocol, BoltResponse, BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, - ResponseMessage, + debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltMeta, BoltProtocol, BoltResponse, + BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, ResponseMessage, }; use crate::error_::Result; use crate::value::ValueReceive; diff --git a/neo4j/src/driver/io/bolt/bolt5x7/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x7/protocol.rs index bf7ca3e..a1c64ca 100644 --- a/neo4j/src/driver/io/bolt/bolt5x7/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x7/protocol.rs @@ -18,7 +18,7 @@ use std::fmt::Debug; use std::io::{Read, Write}; use crate::error::ServerError; -use log::{debug, warn}; +use log::warn; use super::super::bolt5x6::Bolt5x6; use super::super::bolt_common::ServerAwareBoltVersion; @@ -30,8 +30,8 @@ use super::super::message_parameters::{ }; use super::super::response::BoltMeta; use super::super::{ - assert_response_field_count, bolt_debug, bolt_debug_extra, dbg_extra, BoltData, BoltProtocol, - BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, + assert_response_field_count, bolt_debug, BoltData, BoltProtocol, BoltStructTranslator, + OnServerErrorCb, ResponseCallbacks, }; use crate::error_::Result; use crate::value::ValueReceive; diff --git a/neo4j/src/driver/io/bolt/bolt5x8.rs b/neo4j/src/driver/io/bolt/bolt5x8.rs new file mode 100644 index 0000000..937525e --- /dev/null +++ b/neo4j/src/driver/io/bolt/bolt5x8.rs @@ -0,0 +1,19 @@ +// Copyright Rouven Bauer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod protocol; +mod translator; + +pub(crate) use protocol::Bolt5x8; +pub(crate) use translator::Bolt5x8StructTranslator; diff --git a/neo4j/src/driver/io/bolt/bolt5x8/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x8/protocol.rs new file mode 100644 index 0000000..2d05831 --- /dev/null +++ b/neo4j/src/driver/io/bolt/bolt5x8/protocol.rs @@ -0,0 +1,325 @@ +// Copyright Rouven Bauer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Borrow; +use std::collections::HashMap; +use std::fmt::Debug; +use std::io::{Read, Write}; +use std::mem; +use std::net::TcpStream; +use std::ops::Deref; +use std::sync::Arc; + +use crate::driver::io::bolt::bolt5x4::Bolt5x4; +use crate::driver::io::bolt::{BoltResponse, ResponseMessage}; +use log::{debug, log_enabled, warn}; + +use super::super::bolt_common::ServerAwareBoltVersion; +use super::super::message::BoltMessage; +use super::super::message_parameters::{ + BeginParameters, CommitParameters, DiscardParameters, GoodbyeParameters, HelloParameters, + PullParameters, ReauthParameters, ResetParameters, RollbackParameters, RouteParameters, + RunParameters, TelemetryParameters, +}; +use super::super::packstream::{ + PackStreamSerializer, PackStreamSerializerDebugImpl, PackStreamSerializerImpl, +}; +use super::super::response::BoltMeta; +use super::super::{bolt5x0::Bolt5x0, bolt5x2::Bolt5x2, bolt5x3::Bolt5x3, bolt5x7::Bolt5x7}; +use super::super::{ + bolt_debug_extra, debug_buf, debug_buf_end, debug_buf_start, BoltData, BoltProtocol, + BoltStructTranslator, OnServerErrorCb, ResponseCallbacks, +}; +use crate::error_::Result; +use crate::value::ValueReceive; + +const HINTS_KEY: &str = "hints"; +const SSR_ENABLED_KEY: &str = "ssr.enabled"; + +#[derive(Debug)] +pub(crate) struct Bolt5x8 { + pub(in super::super) bolt5x7: Bolt5x7, +} + +impl Bolt5x8 { + pub(in super::super) fn new(protocol_version: ServerAwareBoltVersion) -> Self { + Self { + bolt5x7: Bolt5x7::new(protocol_version), + } + } + + pub(in super::super) fn enqueue_hello_response(data: &mut BoltData) { + let bolt_meta = Arc::clone(&data.meta); + let telemetry_enabled = Arc::clone(&data.telemetry_enabled); + let ssr_enabled = Arc::clone(&data.ssr_enabled); + let bolt_server_agent = Arc::clone(&data.server_agent); + let socket = Arc::clone(&data.socket); + + data.responses.push_back(BoltResponse::new( + ResponseMessage::Hello, + ResponseCallbacks::new().with_on_success(move |mut meta| { + Bolt5x0::::hello_response_handle_agent(&mut meta, &bolt_server_agent); + Self::hello_response_handle_connection_hints( + &meta, + socket.deref().as_ref(), + &mut telemetry_enabled.borrow_mut(), + &mut ssr_enabled.borrow_mut(), + ); + mem::swap(&mut *bolt_meta.borrow_mut(), &mut meta); + Ok(()) + }), + )); + } + + pub(in super::super) fn hello_response_handle_connection_hints( + meta: &BoltMeta, + socket: Option<&TcpStream>, + telemetry_enabled: &mut bool, + ssr_enabled: &mut bool, + ) { + let empty_hints = HashMap::new(); + let hints = match meta.get(HINTS_KEY) { + Some(ValueReceive::Map(hints)) => hints, + Some(value) => { + warn!("Server sent unexpected {HINTS_KEY} type {:?}", value); + &empty_hints + } + None => &empty_hints, + }; + Bolt5x0::::hello_response_handle_timeout_hint(hints, socket); + Bolt5x4::::hello_response_telemetry_hint(hints, telemetry_enabled); + Self::hello_response_handle_ssr_enabled_hint(hints, ssr_enabled); + } + + pub(in super::super) fn hello_response_handle_ssr_enabled_hint( + hints: &HashMap, + ssr_enabled: &mut bool, + ) { + match hints.get(SSR_ENABLED_KEY) { + None => { + *ssr_enabled = false; + } + Some(ValueReceive::Boolean(value)) => *ssr_enabled = *value, + Some(value) => { + *ssr_enabled = false; + warn!("Server sent unexpected {SSR_ENABLED_KEY} type {:?}", value); + } + } + } +} + +impl Default for Bolt5x8 { + fn default() -> Self { + Self::new(ServerAwareBoltVersion::V5x8) + } +} + +impl BoltProtocol for Bolt5x8 { + fn hello( + &mut self, + data: &mut BoltData, + parameters: HelloParameters, + ) -> Result<()> { + let HelloParameters { + user_agent, + auth: _, + routing_context, + notification_filter, + } = parameters; + debug_buf_start!(log_buf); + debug_buf!(log_buf, "C: HELLO"); + let mut dbg_serializer = PackStreamSerializerDebugImpl::new(); + let mut message_buff = Vec::new(); + let mut serializer = PackStreamSerializerImpl::new(&mut message_buff); + serializer.write_struct_header(0x01, 1)?; + + let extra_size = 2 + + Bolt5x2::::notification_filter_entries_count(Some(notification_filter)) + + >::into(routing_context.is_some()); + + serializer.write_dict_header(extra_size)?; + debug_buf!(log_buf, " {}", { + dbg_serializer.write_dict_header(extra_size).unwrap(); + dbg_serializer.flush() + }); + + Bolt5x0::::write_user_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + user_agent, + )?; + + Bolt5x3::::write_bolt_agent_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + )?; + + self.bolt5x7 + .bolt5x6 + .bolt5x4 + .bolt5x3 + .bolt5x2 + .bolt5x1 + .bolt5x0 + .write_routing_context_entry( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + data, + routing_context, + )?; + + Bolt5x2::::write_notification_filter_entries( + log_buf.as_mut(), + &mut serializer, + &mut dbg_serializer, + Some(notification_filter), + )?; + + data.message_buff.push_back(vec![message_buff]); + debug_buf_end!(data, log_buf); + + Self::enqueue_hello_response(data); + Ok(()) + } + + #[inline] + fn reauth( + &mut self, + data: &mut BoltData, + parameters: ReauthParameters, + ) -> Result<()> { + self.bolt5x7.reauth(data, parameters) + } + + #[inline] + fn supports_reauth(&self) -> bool { + self.bolt5x7.supports_reauth() + } + + #[inline] + fn goodbye( + &mut self, + data: &mut BoltData, + parameters: GoodbyeParameters, + ) -> Result<()> { + self.bolt5x7.goodbye(data, parameters) + } + + #[inline] + fn reset( + &mut self, + data: &mut BoltData, + parameters: ResetParameters, + ) -> Result<()> { + self.bolt5x7.reset(data, parameters) + } + + #[inline] + fn run + Debug, KM: Borrow + Debug>( + &mut self, + data: &mut BoltData, + parameters: RunParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.run(data, parameters, callbacks) + } + + #[inline] + fn discard( + &mut self, + data: &mut BoltData, + parameters: DiscardParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.discard(data, parameters, callbacks) + } + + #[inline] + fn pull( + &mut self, + data: &mut BoltData, + parameters: PullParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.pull(data, parameters, callbacks) + } + + #[inline] + fn begin + Debug>( + &mut self, + data: &mut BoltData, + parameters: BeginParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.begin(data, parameters, callbacks) + } + + #[inline] + fn commit( + &mut self, + data: &mut BoltData, + parameters: CommitParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.commit(data, parameters, callbacks) + } + + #[inline] + fn rollback( + &mut self, + data: &mut BoltData, + parameters: RollbackParameters, + ) -> Result<()> { + self.bolt5x7.rollback(data, parameters) + } + + #[inline] + fn route( + &mut self, + data: &mut BoltData, + parameters: RouteParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.route(data, parameters, callbacks) + } + + #[inline] + fn telemetry( + &mut self, + data: &mut BoltData, + parameters: TelemetryParameters, + callbacks: ResponseCallbacks, + ) -> Result<()> { + self.bolt5x7.telemetry(data, parameters, callbacks) + } + + #[inline] + fn load_value(&mut self, reader: &mut R) -> Result { + self.bolt5x7.load_value(reader) + } + + #[inline] + fn handle_response( + &mut self, + bolt_data: &mut BoltData, + message: BoltMessage, + on_server_error: OnServerErrorCb, + ) -> Result<()> { + self.bolt5x7 + .handle_response(bolt_data, message, on_server_error) + } +} diff --git a/neo4j/src/driver/io/bolt/bolt5x8/translator.rs b/neo4j/src/driver/io/bolt/bolt5x8/translator.rs new file mode 100644 index 0000000..fa66eb2 --- /dev/null +++ b/neo4j/src/driver/io/bolt/bolt5x8/translator.rs @@ -0,0 +1,17 @@ +// Copyright Rouven Bauer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::super::bolt5x7::Bolt5x7StructTranslator; + +pub(crate) type Bolt5x8StructTranslator = Bolt5x7StructTranslator; diff --git a/neo4j/src/driver/io/bolt/bolt_common.rs b/neo4j/src/driver/io/bolt/bolt_common.rs index 8b104cd..3f43e72 100644 --- a/neo4j/src/driver/io/bolt/bolt_common.rs +++ b/neo4j/src/driver/io/bolt/bolt_common.rs @@ -114,9 +114,7 @@ pub(super) enum ServerAwareBoltVersion { V5x3, V5x4, V5x6, - #[allow(dead_code)] // bolt versions exists, not yet implemented V5x7, - #[allow(dead_code)] V5x8, } @@ -124,30 +122,29 @@ impl ServerAwareBoltVersion { #[inline] fn protocol_version(&self) -> &'static str { match self { - ServerAwareBoltVersion::V4x4 => "4.4", - ServerAwareBoltVersion::V5x0 => "5.0", - ServerAwareBoltVersion::V5x1 => "5.1", - ServerAwareBoltVersion::V5x2 => "5.2", - ServerAwareBoltVersion::V5x3 => "5.3", - ServerAwareBoltVersion::V5x4 => "5.4", - ServerAwareBoltVersion::V5x6 => "5.6", - ServerAwareBoltVersion::V5x7 => "5.7", - ServerAwareBoltVersion::V5x8 => "5.8", + Self::V4x4 => "4.4", + Self::V5x0 => "5.0", + Self::V5x1 => "5.1", + Self::V5x2 => "5.2", + Self::V5x3 => "5.3", + Self::V5x4 => "5.4", + Self::V5x6 => "5.6", + Self::V5x7 => "5.7", + Self::V5x8 => "5.8", } } #[inline] fn min_server_version(&self) -> &'static str { match self { - ServerAwareBoltVersion::V4x4 => "4.4", - ServerAwareBoltVersion::V5x0 => "5.0", - ServerAwareBoltVersion::V5x1 => "5.5", - ServerAwareBoltVersion::V5x2 => "5.7", - ServerAwareBoltVersion::V5x3 => "5.9", - ServerAwareBoltVersion::V5x4 => "5.13", - ServerAwareBoltVersion::V5x6 => "5.23", - ServerAwareBoltVersion::V5x7 => "5.26", - ServerAwareBoltVersion::V5x8 => "5.26", + Self::V4x4 => "4.4", + Self::V5x0 => "5.0", + Self::V5x1 => "5.5", + Self::V5x2 => "5.7", + Self::V5x3 => "5.9", + Self::V5x4 => "5.13", + Self::V5x6 => "5.23", + Self::V5x7 | Self::V5x8 => "5.26", } } } diff --git a/neo4j/src/driver/io/bolt/bolt_state.rs b/neo4j/src/driver/io/bolt/bolt_state.rs index 6522b40..2e3202e 100644 --- a/neo4j/src/driver/io/bolt/bolt_state.rs +++ b/neo4j/src/driver/io/bolt/bolt_state.rs @@ -15,8 +15,8 @@ use log::debug; use std::collections::HashMap; +use super::bolt_debug_extra; use super::response::ResponseMessage; -use super::{bolt_debug_extra, dbg_extra}; use crate::value::ValueReceive; #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] diff --git a/neo4j/src/driver/io/bolt/handshake.rs b/neo4j/src/driver/io/bolt/handshake.rs index e8cf88b..4ece897 100644 --- a/neo4j/src/driver/io/bolt/handshake.rs +++ b/neo4j/src/driver/io/bolt/handshake.rs @@ -37,7 +37,7 @@ const BOLT_MAGIC_PREAMBLE: [u8; 4] = [0x60, 0x60, 0xB0, 0x17]; // [bolt-version-bump] search tag when changing bolt version support const BOLT_VERSION_OFFER: [u8; 16] = [ 0, 0, 1, 255, // BOLT handshake manifest v1 - 0, 1, 7, 5, // BOLT 5.7 - 5.6 + 0, 2, 8, 5, // BOLT 5.8 - 5.6 0, 4, 4, 5, // BOLT 5.4 - 5.0 0, 0, 4, 4, // BOLT 4.4 ]; @@ -407,8 +407,9 @@ fn decode_version_offer(offer: &[u8; 4]) -> Result<(u8, u8)> { } // [bolt-version-bump] search tag when changing bolt version support -const BOLT_VERSIONS: [(u8, u8); 8] = [ +const BOLT_VERSIONS: [(u8, u8); 9] = [ // important: ordered by descending preference + (5, 8), (5, 7), (5, 6), (5, 4), @@ -546,6 +547,7 @@ mod tests { #[case([0, 0, 4, 5], (5, 4))] #[case([0, 0, 6, 5], (5, 6))] #[case([0, 0, 7, 5], (5, 7))] + #[case([0, 0, 8, 5], (5, 8))] fn test_decode_version_offer( #[case] mut offer: [u8; 4], #[case] expected: (u8, u8), @@ -586,7 +588,7 @@ mod tests { #[case([0, 0, 2, 4])] // driver didn't offer version 4.2 #[case([0, 0, 3, 4])] // driver didn't offer version 4.3 #[case([0, 0, 5, 5])] // driver didn't offer version 5.5 - #[case([0, 0, 8, 5])] // driver didn't offer version 5.8 + #[case([0, 0, 9, 5])] // driver didn't offer version 5.8 #[case([0, 0, 0, 6])] // driver didn't offer version 6.0 fn test_garbage_server_version( #[case] mut offer: [u8; 4], diff --git a/neo4j/src/driver/io/bolt/message_parameters.rs b/neo4j/src/driver/io/bolt/message_parameters.rs index e5b040e..09f6656 100644 --- a/neo4j/src/driver/io/bolt/message_parameters.rs +++ b/neo4j/src/driver/io/bolt/message_parameters.rs @@ -76,7 +76,7 @@ impl ResetParameters { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) struct RunParameters<'a, KP: Borrow + Debug, KM: Borrow + Debug> { pub(super) query: &'a str, pub(super) parameters: Option<&'a HashMap>, @@ -84,7 +84,7 @@ pub(crate) struct RunParameters<'a, KP: Borrow + Debug, KM: Borrow + D pub(super) tx_timeout: Option, pub(super) tx_metadata: Option<&'a HashMap>, pub(super) mode: Option<&'a str>, - pub(super) db: Option<&'a str>, + pub(super) db: Option>, pub(super) imp_user: Option<&'a str>, pub(super) notification_filter: Option<&'a NotificationFilter>, } @@ -98,7 +98,7 @@ impl<'a, KP: Borrow + Debug, KM: Borrow + Debug> RunParameters<'a, KP, tx_timeout: Option, tx_metadata: Option<&'a HashMap>, mode: Option<&'a str>, - db: Option<&'a str>, + db: Option>, imp_user: Option<&'a str>, notification_filter: &'a NotificationFilter, ) -> Self { @@ -159,13 +159,13 @@ impl PullParameters { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) struct BeginParameters<'a, K: Borrow + Debug> { pub(super) bookmarks: Option<&'a Bookmarks>, pub(super) tx_timeout: Option, pub(super) tx_metadata: Option<&'a HashMap>, pub(super) mode: Option<&'a str>, - pub(super) db: Option<&'a str>, + pub(super) db: Option>, pub(super) imp_user: Option<&'a str>, pub(super) notification_filter: &'a NotificationFilter, } @@ -176,7 +176,7 @@ impl<'a, K: Borrow + Debug> BeginParameters<'a, K> { tx_timeout: Option, tx_metadata: Option<&'a HashMap>, mode: Option<&'a str>, - db: Option<&'a str>, + db: Option>, imp_user: Option<&'a str>, notification_filter: &'a NotificationFilter, ) -> Self { diff --git a/neo4j/src/driver/io/deadline.rs b/neo4j/src/driver/io/deadline.rs index c4437bf..109d0ed 100644 --- a/neo4j/src/driver/io/deadline.rs +++ b/neo4j/src/driver/io/deadline.rs @@ -87,13 +87,13 @@ impl<'tcp, S: Read + Write> DeadlineIO<'tcp, S> { }; let old_timeout = self.wrap_io_error(get_socket_timeout(socket), ReaderErrorDuring::GetTimeout)?; - let timeout = match deadline.checked_duration_since(Instant::now()) { - // deadline in the past - // => we set a tiny timeout to trigger a timeout error on pretty much any blocking - None => Duration::from_nanos(1), - // deadline in the future - Some(timeout) => timeout, - }; + let timeout = deadline + .checked_duration_since(Instant::now()) + .unwrap_or_else(|| { + // deadline in the past + // => we set a tiny timeout to trigger a timeout error on pretty much any blocking + Duration::from_nanos(1) + }); if let Some(old_timeout) = old_timeout { if timeout >= old_timeout { let res = work(self); diff --git a/neo4j/src/driver/io/pool.rs b/neo4j/src/driver/io/pool.rs index e3367d4..cc59fef 100644 --- a/neo4j/src/driver/io/pool.rs +++ b/neo4j/src/driver/io/pool.rs @@ -14,18 +14,20 @@ mod routing; mod single_pool; +mod ssr_tracker; use std::cell::RefCell; use std::collections::{HashMap, HashSet}; +use std::fmt::{Debug, Formatter}; use std::io::{Read, Write}; -use std::mem; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Duration; +use std::{fmt, mem}; use atomic_refcell::AtomicRefCell; use itertools::Itertools; -use log::{debug, info, warn}; +use log::{debug, error, info, warn}; use parking_lot::{Condvar, Mutex, RwLockReadGuard}; use rustls::ClientConfig; @@ -47,6 +49,7 @@ use routing::RoutingTable; pub use single_pool::ConnectionPoolMetrics; pub(crate) use single_pool::SessionAuth; use single_pool::{SimplePool, SinglePooledBolt, UnpreparedSinglePooledBolt}; +use ssr_tracker::SsrTracker; // 7 is a reasonable common upper bound for the size of clusters // this is, however, not a hard limit @@ -170,15 +173,21 @@ impl PoolConfig { #[derive(Debug)] pub(crate) struct Pool { - config: Arc, + pub(crate) config: Arc, + ssr_tracker: Arc, pools: Pools, } impl Pool { pub(crate) fn new(address: Arc
, config: PoolConfig) -> Self { let config = Arc::new(config); - let pools = Pools::new(address, Arc::clone(&config)); - Self { config, pools } + let ssr_tracker = Arc::new(SsrTracker::new()); + let pools = Pools::new(address, Arc::clone(&config), Arc::clone(&ssr_tracker)); + Self { + config, + ssr_tracker, + pools, + } } #[inline] @@ -191,6 +200,11 @@ impl Pool { self.config.tls_config.is_some() } + #[inline] + pub(crate) fn ssr_enabled(&self) -> bool { + self.ssr_tracker.ssr_enabled() + } + #[cfg(feature = "_internal_testkit_backend")] #[inline] pub(crate) fn get_metrics(&self, address: Arc
) -> Option { @@ -220,10 +234,9 @@ impl Pool { bolt: Some(match &self.pools { Pools::Direct(single_pool) => { let mut connection = None; - let deadline = self.config.connection_acquisition_deadline(); while connection.is_none() { - connection = single_pool.acquire(deadline)?.prepare( - deadline, + connection = single_pool.acquire(args.update_rt_args.deadline)?.prepare( + args.update_rt_args.deadline, args.update_rt_args.idle_time_before_connection_test, args.update_rt_args.session_auth, None, @@ -275,10 +288,10 @@ enum PoolsRef<'a> { } impl Pools { - fn new(address: Arc
, config: Arc) -> Self { + fn new(address: Arc
, config: Arc, ssr_tracker: Arc) -> Self { match config.routing_context { - None => Pools::Direct(SimplePool::new(address, config)), - Some(_) => Pools::Routing(RoutingPool::new(address, config)), + None => Pools::Direct(SimplePool::new(address, config, ssr_tracker)), + Some(_) => Pools::Routing(RoutingPool::new(address, config, ssr_tracker)), } } @@ -307,10 +320,11 @@ struct RoutingPool { routing_tables: MostlyRLock, address: Arc
, config: Arc, + ssr_tracker: Arc, } impl RoutingPool { - fn new(address: Arc
, config: Arc) -> Self { + fn new(address: Arc
, config: Arc, ssr_tracker: Arc) -> Self { assert!(config.routing_context.is_some()); Self { pools: MostlyRLock::new(HashMap::with_capacity(DEFAULT_CLUSTER_SIZE)), @@ -318,12 +332,21 @@ impl RoutingPool { routing_tables: MostlyRLock::new(HashMap::new()), address, config, + ssr_tracker, } } fn acquire(&self, args: AcquireConfig) -> Result { + debug!( + "acquiring {:?} connection towards {}", + args.mode, + args.update_rt_args + .db + .map(|db| format!("{:?}", db)) + .unwrap_or(String::from("default database")) + ); let (mut targets, db) = self.choose_addresses_from_fresh_rt(args)?; - let deadline = self.config.connection_acquisition_deadline(); + let deadline = args.update_rt_args.deadline; 'target: for target in &targets { while let Some(connection) = self.acquire_routing_address_no_wait(target) { let mut on_server_error = @@ -435,19 +458,18 @@ impl RoutingPool { args: UpdateRtArgs, ) -> Result { let mut connection = None; - let deadline = self.config.connection_acquisition_deadline(); while connection.is_none() { let unprepared_connection = { let pools = self.ensure_pool_exists(target); pools .get(target) .expect("just created above") - .acquire(deadline) + .acquire(args.deadline) }?; let mut on_server_error = |bolt_data: &mut _, error: &mut _| self.handle_server_error(bolt_data, error); connection = unprepared_connection.prepare( - deadline, + args.deadline, args.idle_time_before_connection_test, args.session_auth, Some(&mut on_server_error), @@ -463,7 +485,11 @@ impl RoutingPool { |mut rt| { rt.insert( Arc::clone(target), - SimplePool::new(Arc::clone(target), Arc::clone(&self.config)), + SimplePool::new( + Arc::clone(target), + Arc::clone(&self.config), + Arc::clone(&self.ssr_tracker), + ), ); Ok(()) }, @@ -476,16 +502,22 @@ impl RoutingPool { args: AcquireConfig, ) -> Result<(RwLockReadGuard, Option>)> { let rt_args = args.update_rt_args; - let db_name = RefCell::new(rt_args.db.cloned()); + let db_key = rt_args.rt_key(); + let db_name = RefCell::new(rt_args.db_name()); let db_name_ref = &db_name; let lock = self.routing_tables.maybe_write( |rts| { - rts.get(&*db_name_ref.borrow()) + let needs_update = rts + .get(&db_key) .map(|rt| !rt.is_fresh(args.mode)) - .unwrap_or(true) + .unwrap_or(true); + if !needs_update { + mem::swap(&mut *db_name_ref.borrow_mut(), &mut db_key.clone()); + } + needs_update }, |mut rts| { - let key = rt_args.db.cloned(); + let key = rt_args.rt_key(); let rt = rts.entry(key).or_insert_with(|| self.empty_rt()); if !rt.is_fresh(args.mode) { let mut new_db = self.update_rts(rt_args, &mut rts)?; @@ -522,7 +554,8 @@ impl RoutingPool { args: UpdateRtArgs, rts: &mut RoutingTables, ) -> Result>> { - let rt_key = args.db.cloned(); + debug!("Fetching new routing table for {:?}", args.db); + let rt_key = args.rt_key(); let rt = rts.entry(rt_key).or_insert_with(|| self.empty_rt()); let pref_init_router = rt.initialized_without_writers; let mut new_rt: Result; @@ -544,23 +577,35 @@ impl RoutingPool { } } match new_rt { - Err(err) => Err(Neo4jError::disconnect(format!( - "unable to retrieve routing information; last error: {}", - err - ))), + Err(err) => { + error!("failed to update routing table; last error: {}", err); + Err(Neo4jError::disconnect(format!( + "unable to retrieve routing information; last error: {}", + err + ))) + } Ok(mut new_rt) => { - if args.db.is_some() { - let db = args.db.cloned(); - new_rt.database.clone_from(&db); - rts.insert(db.clone(), new_rt); - self.clean_up_pools(rts); - Ok(db) - } else { - let db = new_rt.database.clone(); - rts.insert(db.clone(), new_rt); - self.clean_up_pools(rts); - Ok(db) + let db = match args.db { + Some(args_db) if !args_db.guess => { + let db = Some(Arc::clone(&args_db.db)); + new_rt.database.clone_from(&db); + debug!("Storing new routing table for {:?}: {:?}", db, new_rt); + rts.insert(db.as_ref().map(Arc::clone), new_rt); + self.clean_up_pools(rts); + db + } + _ => { + let db = new_rt.database.clone(); + debug!("Storing new routing table for {:?}: {:?}", db, new_rt); + rts.insert(db.clone(), new_rt); + self.clean_up_pools(rts); + db + } + }; + if let Some(cb) = args.db_resolution_cb { + cb(db.as_ref().map(Arc::clone)); } + Ok(db) } } } @@ -588,10 +633,9 @@ impl RoutingPool { self.deactivate_server_locked_rts(&resolved, rts); } } - Ok(Err(match last_err { - None => Neo4jError::disconnect("no known routers left"), - Some(err) => err, - })) + Ok(Err(last_err.unwrap_or_else(|| { + Neo4jError::disconnect("no known routers left") + }))) } fn fetch_rt_from_router( @@ -604,7 +648,7 @@ impl RoutingPool { RouteParameters::new( self.config.routing_context.as_ref().unwrap(), args.bookmarks, - args.db.as_ref().map(|db| db.as_str()), + args.db_request_str(), args.imp_user, ), ResponseCallbacks::new().with_on_success({ @@ -830,11 +874,58 @@ pub(crate) struct AcquireConfig<'a> { pub(crate) update_rt_args: UpdateRtArgs<'a>, } -#[derive(Debug, Copy, Clone)] +#[derive(Copy, Clone)] pub(crate) struct UpdateRtArgs<'a> { - pub(crate) db: Option<&'a Arc>, + pub(crate) db: Option<&'a UpdateRtDb>, pub(crate) bookmarks: Option<&'a Bookmarks>, pub(crate) imp_user: Option<&'a str>, pub(crate) session_auth: SessionAuth<'a>, + pub(crate) deadline: Option, pub(crate) idle_time_before_connection_test: Option, + pub(crate) db_resolution_cb: Option<&'a dyn Fn(Option>)>, +} + +impl Debug for UpdateRtArgs<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("UpdateRtArgs") + .field("db", &self.db) + .field("bookmarks", &self.bookmarks) + .field("imp_user", &self.imp_user) + .field("session_auth", &self.session_auth) + .field( + "idle_time_before_connection_test", + &self.idle_time_before_connection_test, + ) + .field( + "db_resolution_cb", + &self.db_resolution_cb.as_ref().map(|_| "..."), + ) + .finish() + } +} + +impl UpdateRtArgs<'_> { + fn rt_key(&self) -> Option> { + self.db.as_ref().map(|db| Arc::clone(&db.db)) + } + + fn db_request_str(&self) -> Option<&str> { + self.db.as_ref().and_then(|db| match db.guess { + true => None, + false => Some(db.db.as_str()), + }) + } + + fn db_name(&self) -> Option> { + self.db.as_ref().and_then(|db| match db.guess { + true => None, + false => Some(Arc::clone(&db.db)), + }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct UpdateRtDb { + pub(crate) db: Arc, + pub(crate) guess: bool, } diff --git a/neo4j/src/driver/io/pool/routing.rs b/neo4j/src/driver/io/pool/routing.rs index fe37cf1..f22139e 100644 --- a/neo4j/src/driver/io/pool/routing.rs +++ b/neo4j/src/driver/io/pool/routing.rs @@ -40,8 +40,8 @@ pub(crate) struct RoutingTable { impl RoutingTable { pub(crate) fn new(initial_router: Arc
) -> Self { Self { - routers: Vec::new(), - readers: vec![initial_router], + routers: vec![initial_router], + readers: Vec::new(), writers: Vec::new(), database: None, initialized_without_writers: true, diff --git a/neo4j/src/driver/io/pool/single_pool.rs b/neo4j/src/driver/io/pool/single_pool.rs index 4132f2e..53c9d97 100644 --- a/neo4j/src/driver/io/pool/single_pool.rs +++ b/neo4j/src/driver/io/pool/single_pool.rs @@ -22,12 +22,12 @@ use parking_lot::lock_api::MutexGuard; use parking_lot::{Condvar, Mutex, RawMutex}; use super::super::bolt::message_parameters::{HelloParameters, ReauthParameters}; -use super::super::bolt::{self, OnServerErrorCb, TcpBolt, TcpRW}; +use super::super::bolt::{self, AuthResetHandle, OnServerErrorCb, TcpBolt, TcpRW}; +use super::super::pool::ssr_tracker::SsrTracker; use super::PoolConfig; use crate::address_::Address; use crate::driver::config::auth::{auth_managers, AuthToken}; use crate::driver::config::AuthConfig; -use crate::driver::io::bolt::AuthResetHandle; use crate::error_::{Neo4jError, Result}; use crate::time::Instant; use crate::util::RefContainer; @@ -38,6 +38,7 @@ type PoolElement = TcpBolt; pub(crate) struct InnerPool { address: Arc
, config: Arc, + ssr_tracker: Arc, synced: Mutex, made_room_condition: Condvar, } @@ -51,7 +52,7 @@ struct InnerPoolSyncedData { } impl InnerPool { - fn new(address: Arc
, config: Arc) -> Self { + fn new(address: Arc
, config: Arc, ssr_tracker: Arc) -> Self { let raw_pool = VecDeque::with_capacity(config.max_connection_pool_size); // allow: `AuthResetHandle::hash` hashes by pointer address, not value #[allow(clippy::mutable_key_type)] @@ -65,6 +66,7 @@ impl InnerPool { Self { address, config, + ssr_tracker, synced, made_room_condition: Condvar::new(), } @@ -122,6 +124,7 @@ impl InnerPool { } connection.write_all(deadline)?; connection.read_all(deadline, None)?; + self.ssr_tracker.add_connection(&connection); Ok(connection) } @@ -164,8 +167,12 @@ impl InnerPool { pub(crate) struct SimplePool(Arc); impl SimplePool { - pub(crate) fn new(address: Arc
, config: Arc) -> Self { - Self(Arc::new(InnerPool::new(address, config))) + pub(crate) fn new( + address: Arc
, + config: Arc, + ssr_tracker: Arc, + ) -> Self { + Self(Arc::new(InnerPool::new(address, config, ssr_tracker))) } pub(crate) fn acquire(&self, deadline: Option) -> Result { @@ -261,11 +268,6 @@ impl SimplePool { } fn release(inner_pool: &Arc, mut connection: PoolElement) { - let mut lock = inner_pool.synced.lock(); - lock.borrowed -= 1; - assert!(lock - .borrowed_auth_reset - .remove(&connection.auth_reset_handler())); if connection.needs_reset() { let res = connection .reset() @@ -275,10 +277,19 @@ impl SimplePool { info!("ignoring failure during reset, dropping connection"); } } - if !connection.closed() { + let mut lock = inner_pool.synced.lock(); + assert!(lock + .borrowed_auth_reset + .remove(&connection.auth_reset_handler())); + lock.borrowed -= 1; + if connection.closed() { + inner_pool.made_room_condition.notify_one(); + drop(lock); + inner_pool.ssr_tracker.remove_connection(&connection); + } else { lock.raw_pool.push_back(connection); + inner_pool.made_room_condition.notify_one(); } - inner_pool.made_room_condition.notify_one(); } #[cfg(feature = "_internal_testkit_backend")] @@ -330,6 +341,7 @@ impl UnpreparedSinglePooledBolt { if connection.is_older_than(max_lifetime) { connection.debug_log(|| String::from("connection reached max lifetime")); connection.close(); + self.bolt = Some(connection); return Ok(None); } } @@ -343,7 +355,7 @@ impl UnpreparedSinglePooledBolt { on_server_error, ) { connection.debug_log(|| format!("liveness check failed: {}", err)); - SimplePool::release(&self.pool, connection); + self.bolt = Some(connection); return Ok(None); } } diff --git a/neo4j/src/driver/io/pool/ssr_tracker.rs b/neo4j/src/driver/io/pool/ssr_tracker.rs new file mode 100644 index 0000000..59c70c8 --- /dev/null +++ b/neo4j/src/driver/io/pool/ssr_tracker.rs @@ -0,0 +1,86 @@ +// Copyright Rouven Bauer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io::{Read, Write}; +use std::sync::atomic::AtomicUsize; + +use super::super::bolt::Bolt; + +#[derive(Debug, Default)] +pub(crate) struct SsrTracker { + with_ssr: AtomicUsize, + without_ssr: AtomicUsize, +} + +impl SsrTracker { + pub(super) fn new() -> Self { + Default::default() + } + + fn with_ssr(&self) -> usize { + self.with_ssr.load(std::sync::atomic::Ordering::Relaxed) + } + + fn increment_with_ssr(&self) { + self.with_ssr + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + fn decrement_with_ssr(&self) { + self.with_ssr + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } + + fn increment_without_ssr(&self) { + self.without_ssr + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + fn decrement_without_ssr(&self) { + self.without_ssr + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } + + fn without_ssr(&self) -> usize { + self.without_ssr.load(std::sync::atomic::Ordering::Relaxed) + } + + pub(super) fn add_connection(&self, connection: &impl SsrTrackable) { + match connection.ssr_enabled() { + true => self.increment_with_ssr(), + false => self.increment_without_ssr(), + } + } + + pub(super) fn remove_connection(&self, connection: &impl SsrTrackable) { + match connection.ssr_enabled() { + true => self.decrement_with_ssr(), + false => self.decrement_without_ssr(), + } + } + + pub(super) fn ssr_enabled(&self) -> bool { + self.with_ssr() > 0 && self.without_ssr() == 0 + } +} + +pub(super) trait SsrTrackable { + fn ssr_enabled(&self) -> bool; +} + +impl SsrTrackable for Bolt { + fn ssr_enabled(&self) -> bool { + self.ssr_enabled() + } +} diff --git a/neo4j/src/driver/record_stream.rs b/neo4j/src/driver/record_stream.rs index 2ab317f..60119e7 100644 --- a/neo4j/src/driver/record_stream.rs +++ b/neo4j/src/driver/record_stream.rs @@ -44,6 +44,8 @@ pub struct RecordStream<'driver> { listener: Arc>, } +type BoltMetaCb = Box; + impl<'driver> RecordStream<'driver> { pub(crate) fn new( connection: Rc>>, @@ -71,6 +73,7 @@ impl<'driver> RecordStream<'driver> { pub(crate) fn run + Debug, KM: Borrow + Debug>( &mut self, parameters: RunParameters, + mut db_resolution_cb: Option, ) -> Result<()> { if let RecordListenerState::ForeignError(e) = &(*self.listener).borrow().state { return Err(Neo4jError::ServerError { @@ -80,7 +83,10 @@ impl<'driver> RecordStream<'driver> { let mut callbacks = self.failure_callbacks(); let listener = Arc::downgrade(&self.listener); - callbacks = callbacks.with_on_success(move |meta| { + callbacks = callbacks.with_on_success(move |mut meta| { + if let Some(db_resolution_cb) = db_resolution_cb.as_mut() { + db_resolution_cb(&mut meta) + } if let Some(listener) = listener.upgrade() { return listener.borrow_mut().run_success_cb(meta); } diff --git a/neo4j/src/driver/session.rs b/neo4j/src/driver/session.rs index 7068cfb..a2d5c45 100644 --- a/neo4j/src/driver/session.rs +++ b/neo4j/src/driver/session.rs @@ -25,23 +25,25 @@ use std::marker::PhantomData; use std::ops::Deref; use std::rc::Rc; use std::result::Result as StdResult; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use log::{debug, info}; use super::config::auth::AuthToken; +use super::home_db_cache::{HomeDbCache, HomeDbCacheKey}; use super::io::bolt::message_parameters::{ BeginParameters, RunParameters, TelemetryAPI, TelemetryParameters, }; -use super::io::bolt::ResponseCallbacks; -use super::io::{AcquireConfig, Pool, PooledBolt, UpdateRtArgs}; +use super::io::bolt::{BoltMeta, ResponseCallbacks}; +use super::io::{AcquireConfig, Pool, PooledBolt, UpdateRtArgs, UpdateRtDb}; use super::record_stream::{ErrorPropagator, RecordStream, SharedErrorPropagator}; use super::transaction::{Transaction, TransactionTimeout}; use super::{EagerResult, ReducedDriverConfig, RoutingControl}; use crate::driver::io::SessionAuth; use crate::error_::{Neo4jError, Result}; +use crate::time::Instant; use crate::transaction::InnerTransaction; -use crate::value::ValueSend; +use crate::value::{ValueReceive, ValueSend}; use bookmarks::{bookmark_managers, BookmarkManager, Bookmarks}; use config::InternalSessionConfig; pub use config::SessionConfig; @@ -75,25 +77,35 @@ use super::Driver; pub struct Session<'driver> { config: InternalSessionConfig, pool: &'driver Pool, + home_db_cache: Arc, driver_config: &'driver ReducedDriverConfig, - resolved_db: Option>, + target_db: Arc>, + home_db_cache_key: OnceLock, session_bookmarks: SessionBookmarks, + current_acquisition_deadline: Option, } impl<'driver> Session<'driver> { - pub(crate) fn new( + pub(super) fn new( config: InternalSessionConfig, pool: &'driver Pool, + home_db_cache: Arc, driver_config: &'driver ReducedDriverConfig, ) -> Self { let bookmarks = config.config.bookmarks.clone(); let manager = config.config.as_ref().bookmark_manager.clone(); + let target_db = Arc::new(AtomicRefCell::new(SessionTargetDb::new_init( + config.config.database.clone(), + ))); Session { config, pool, + home_db_cache, driver_config, - resolved_db: None, + target_db, + home_db_cache_key: Default::default(), session_bookmarks: SessionBookmarks::new(bookmarks, manager), + current_acquisition_deadline: None, } } @@ -147,23 +159,27 @@ impl<'driver> Session<'driver> { true, None, ); + let target_db = AtomicRefCell::borrow(&self.target_db).as_db(); let res = record_stream - .run(RunParameters::new_auto_commit_run( - builder.query.as_ref(), - Some(builder.param.borrow()), - Some(&*self.session_bookmarks.get_bookmarks_for_work()?), - builder.timeout.raw(), - Some(builder.meta.borrow()), - builder.mode.as_protocol_str(), - self.resolved_db().as_ref().map(|db| db.as_str()), - self.config - .config - .as_ref() - .impersonated_user - .as_ref() - .map(|imp| imp.as_str()), - &self.config.config.notification_filter, - )) + .run( + RunParameters::new_auto_commit_run( + builder.query.as_ref(), + Some(builder.param.borrow()), + Some(&*self.session_bookmarks.get_bookmarks_for_work()?), + builder.timeout.raw(), + Some(builder.meta.borrow()), + builder.mode.as_protocol_str(), + target_db, + self.config + .config + .as_ref() + .impersonated_user + .as_ref() + .map(|imp| imp.as_str()), + &self.config.config.notification_filter, + ), + Some(Box::new(self.make_db_meta_resolution_cb())), + ) .and_then(|_| (builder.receiver)(&mut record_stream)); let res = match res { Ok(r) => { @@ -228,7 +244,7 @@ impl<'driver> Session<'driver> { builder.timeout.raw(), Some(builder.meta.borrow()), builder.mode.as_protocol_str(), - self.resolved_db().as_ref().map(|db| db.as_str()), + AtomicRefCell::borrow(&self.target_db).as_db(), self.config .config .impersonated_user @@ -236,10 +252,18 @@ impl<'driver> Session<'driver> { .map(|imp| imp.as_str()), &self.config.config.notification_filter, ); + tx.begin( parameters, self.config.eager_begin, ResponseCallbacks::new() + .with_on_success({ + let db_cb = self.make_db_meta_resolution_cb(); + move |mut meta| { + db_cb(&mut meta); + Ok(()) + } + }) .with_on_failure(ErrorPropagator::make_on_error_cb(error_propagator)), )?; let res = receiver(Transaction::new(&mut tx)); @@ -267,9 +291,38 @@ impl<'driver> Session<'driver> { } fn resolve_db(&mut self) -> Result<()> { - if self.resolved_db().is_none() && self.pool.is_routing() { - debug!("Resolving home db"); - self.resolved_db = self.pool.resolve_home_db(UpdateRtArgs { + let mut target_db = AtomicRefCell::borrow_mut(&self.target_db); + if target_db.pinned + || target_db + .target + .as_ref() + .map(|t| !t.guess) + .unwrap_or_default() + || !self.pool.is_routing() + { + debug!( + "Targeting fixed db: {:?}", + target_db.target.as_ref().map(|t| t.db.as_str()) + ); + target_db.pinned = true; + return Ok(()); + } + if self.pool.ssr_enabled() { + if let Some(cached_db) = self.home_db_cache.get(self.home_db_cache_key()) { + debug!("Targeting cached home db: {:?}", cached_db.as_str()); + *target_db = SessionTargetDb::new_guess(cached_db); + return Ok(()); + } + } + drop(target_db); + + self.resolve_db_forced() + } + + fn resolve_db_forced(&mut self) -> Result<()> { + debug!("Resolving home db"); + self.pool + .resolve_home_db(UpdateRtArgs { db: None, bookmarks: Some(&*self.session_bookmarks.get_bookmarks_for_work()?), imp_user: self @@ -279,40 +332,132 @@ impl<'driver> Session<'driver> { .as_ref() .map(|imp| imp.as_str()), session_auth: self.session_auth(), + deadline: self.current_acquisition_deadline, idle_time_before_connection_test: self.config.idle_time_before_connection_test, - })?; - debug!("Resolved home db to {:?}", &self.resolved_db); + db_resolution_cb: Some(&self.make_db_resolution_cb()), + }) + .map(|_| ()) + } + + fn make_db_meta_resolution_cb(&self) -> impl Fn(&mut BoltMeta) + Send + Sync + 'static { + let base_cb = self.make_db_resolution_cb(); + move |meta| { + let db = match meta.remove("db") { + Some(ValueReceive::String(db)) => Some(Arc::new(db)), + _ => None, + }; + base_cb(db); } - Ok(()) } - #[inline] - fn resolved_db(&self) -> &Option> { - match self.resolved_db { - None => &self.config.config.database, - Some(_) => &self.resolved_db, + fn make_db_resolution_cb_if_needed( + &self, + ) -> Option>) + Send + Sync + 'static> { + if !self.pool.is_routing() { + return None; } + { + let target_db = AtomicRefCell::borrow(&self.target_db); + if target_db.pinned || !target_db.target.as_ref().map(|t| t.guess).unwrap_or(true) { + return None; + } + }; + Some(self.make_db_resolution_cb()) + } + + fn make_db_resolution_cb(&self) -> impl Fn(Option>) + Send + Sync + 'static { + let cache = Arc::clone(&self.home_db_cache); + let target_db = Arc::clone(&self.target_db); + let key = self.home_db_cache_key().clone(); + move |db| { + if let Some(db) = db.as_ref() { + cache.update(key.clone(), Arc::clone(db)); + } + { + let mut target_db = AtomicRefCell::borrow_mut(&target_db); + if !target_db.pinned { + debug!("Pinning db: {:?}", db.as_ref().map(|d| d.as_str())); + *target_db = SessionTargetDb::new_pinned(db); + } + } + } + } + + fn home_db_cache_key(&self) -> &HomeDbCacheKey { + self.home_db_cache_key.get_or_init(|| { + HomeDbCacheKey::new( + self.config.config.impersonated_user.as_ref(), + self.config.config.auth.as_ref(), + ) + }) } pub(super) fn acquire_connection( &mut self, mode: RoutingControl, ) -> Result> { + self.acquire_connection_args(mode, AcquireArgs::default()) + } + + fn acquire_connection_args( + &mut self, + mode: RoutingControl, + args: AcquireArgs, + ) -> Result> { + self.current_acquisition_deadline = self.pool.config.connection_acquisition_deadline(); self.resolve_db()?; let bookmarks = self.session_bookmarks.get_bookmarks_for_work()?; + let target = AtomicRefCell::borrow(&self.target_db).target.clone(); + let connection = self.no_resolve_acquire_connection( + mode, + Some(&*bookmarks), + target.as_ref(), + args.session_auth.unwrap_or_else(|| self.session_auth()), + )?; + if target.as_ref().map(|t| t.guess).unwrap_or_default() && !connection.ssr_enabled() { + debug!( + "Used db cached, received connection without SSR => \ + returning connection and falling back to explicit db resolution" + ); + drop(connection); + self.resolve_db_forced()?; + let target = AtomicRefCell::borrow(&self.target_db).target.clone(); + self.no_resolve_acquire_connection( + mode, + Some(&*bookmarks), + target.as_ref(), + args.session_auth.unwrap_or_else(|| self.session_auth()), + ) + } else { + Ok(connection) + } + } + + fn no_resolve_acquire_connection( + &self, + mode: RoutingControl, + bookmarks: Option<&Bookmarks>, + db: Option<&UpdateRtDb>, + session_auth: SessionAuth, + ) -> Result> { self.pool.acquire(AcquireConfig { mode, update_rt_args: UpdateRtArgs { - db: self.resolved_db().as_ref(), - bookmarks: Some(&*bookmarks), + db, + bookmarks, imp_user: self .config .config .impersonated_user .as_ref() .map(|imp| imp.as_str()), - session_auth: self.session_auth(), + session_auth, + deadline: self.current_acquisition_deadline, idle_time_before_connection_test: self.config.idle_time_before_connection_test, + db_resolution_cb: self + .make_db_resolution_cb_if_needed() + .as_ref() + .map(|cb| cb as _), }, }) } @@ -338,23 +483,10 @@ impl<'driver> Session<'driver> { } fn forced_auth(&mut self, auth: &Arc) -> Result<()> { - self.resolve_db()?; - let bookmarks = self.session_bookmarks.get_bookmarks_for_work()?; - let mut connection = self.pool.acquire(AcquireConfig { - mode: RoutingControl::Read, - update_rt_args: UpdateRtArgs { - db: self.resolved_db().as_ref(), - bookmarks: Some(&*bookmarks), - imp_user: self - .config - .config - .impersonated_user - .as_ref() - .map(|imp| imp.as_str()), - session_auth: SessionAuth::Forced(auth), - idle_time_before_connection_test: self.config.idle_time_before_connection_test, - }, - })?; + let args = AcquireArgs { + session_auth: Some(SessionAuth::Forced(auth)), + }; + let mut connection = self.acquire_connection_args(RoutingControl::Read, args)?; connection.write_all(None)?; connection.read_all(None) } @@ -1135,3 +1267,45 @@ impl SessionBookmarks { Ok(()) } } + +#[derive(Debug, Default)] +struct AcquireArgs<'a> { + session_auth: Option>, +} + +#[derive(Debug, Default)] +struct SessionTargetDb { + target: Option, + pinned: bool, +} + +impl SessionTargetDb { + fn new_init(target: Option>) -> Self { + Self { + target: target.map(|db| UpdateRtDb { db, guess: false }), + pinned: false, + } + } + + fn new_guess(db: Arc) -> Self { + Self { + target: Some(UpdateRtDb { db, guess: true }), + pinned: false, + } + } + + fn new_pinned(target: Option>) -> Self { + Self { + target: target.map(|db| UpdateRtDb { db, guess: false }), + pinned: true, + } + } + + fn as_db(&self) -> Option> { + if self.pinned || self.target.as_ref().map(|t| !t.guess).unwrap_or_default() { + self.target.as_ref().map(|t| Arc::clone(&t.db)) + } else { + None + } + } +} diff --git a/neo4j/src/driver/transaction.rs b/neo4j/src/driver/transaction.rs index 9378aef..65fc935 100644 --- a/neo4j/src/driver/transaction.rs +++ b/neo4j/src/driver/transaction.rs @@ -249,7 +249,10 @@ impl<'driver> InnerTransaction<'driver> { false, Some(Arc::clone(&self.error_propagator)), ); - record_stream.run(RunParameters::new_transaction_run(query, Some(parameters)))?; + record_stream.run( + RunParameters::new_transaction_run(query, Some(parameters)), + None, + )?; Ok(record_stream) } diff --git a/neo4j/src/lib.rs b/neo4j/src/lib.rs index 9f6a47f..1f5e60d 100644 --- a/neo4j/src/lib.rs +++ b/neo4j/src/lib.rs @@ -23,10 +23,12 @@ //! //! ## Compatibility // [bolt-version-bump] search tag when changing bolt version support -//! This driver supports bolt protocol version 4.4, and 5.0 - 5.7. -//! This corresponds to Neo4j versions 4.4, and 5.0 - 5.26+. +//! This driver supports bolt protocol version 4.4, and 5.0 - 5.8. +//! This corresponds to Neo4j versions 4.4, and the whole 5.x series. +//! Newer versions of Neo4j are supported as long as they keep support for at least one of the +//! protocol versions mentioned above. //! For details of bolt protocol compatibility, see the -//! [official Neo4j documentation](https://neo4j.com/docs/bolt/current/bolt-compatibility/). +//! [official Neo4j documentation](https://7687.org/bolt-compatibility/). //! //! ## Basic Example //! ``` diff --git a/neo4j/src/test_data/public-api.txt b/neo4j/src/test_data/public-api.txt index ef78c49..c611a3c 100644 --- a/neo4j/src/test_data/public-api.txt +++ b/neo4j/src/test_data/public-api.txt @@ -738,7 +738,7 @@ pub fn neo4j::session::Session<'driver>::last_bookmarks(&self) -> alloc::sync::A pub fn neo4j::session::Session<'driver>::transaction<'session>(&'session mut self) -> neo4j::session::TransactionBuilder<'driver, 'session, alloc::string::String, std::collections::hash::map::HashMap> impl<'driver> core::fmt::Debug for neo4j::session::Session<'driver> pub fn neo4j::session::Session<'driver>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl<'driver> core::marker::Freeze for neo4j::session::Session<'driver> +impl<'driver> !core::marker::Freeze for neo4j::session::Session<'driver> impl<'driver> core::marker::Send for neo4j::session::Session<'driver> impl<'driver> core::marker::Sync for neo4j::session::Session<'driver> impl<'driver> core::marker::Unpin for neo4j::session::Session<'driver> diff --git a/neo4j/src/value/value_send.rs b/neo4j/src/value/value_send.rs index cc79947..e5a7ef8 100644 --- a/neo4j/src/value/value_send.rs +++ b/neo4j/src/value/value_send.rs @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; - use super::spatial; use super::time; use super::value_receive::ValueReceive; use super::ValueConversionError; +use itertools::Itertools; +use std::collections::HashMap; // imports for docs #[allow(unused)] @@ -73,7 +73,8 @@ impl ValueSend { ValueSend::Map(v1) => match other { ValueSend::Map(v2) if v1.len() == v2.len() => v1 .iter() - .zip(v2.iter()) + .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2)) + .zip(v2.iter().sorted_by(|(k1, _), (k2, _)| k1.cmp(k2))) .all(|((k1, v1), (k2, v2))| k1 == k2 && v1.eq_data(v2)), _ => false, }, diff --git a/testkit_backend/src/testkit_backend/requests.rs b/testkit_backend/src/testkit_backend/requests.rs index 6f86512..c67a45e 100644 --- a/testkit_backend/src/testkit_backend/requests.rs +++ b/testkit_backend/src/testkit_backend/requests.rs @@ -75,6 +75,7 @@ pub(super) enum Request { fetch_size: Option, max_tx_retry_time_ms: Option, liveness_check_timeout_ms: Option, + max_connection_lifetime_ms: Option, max_connection_pool_size: Option, connection_acquisition_timeout_ms: Option, #[serde(rename = "clientCertificate")] @@ -818,6 +819,7 @@ impl Request { fetch_size, max_tx_retry_time_ms, liveness_check_timeout_ms, + max_connection_lifetime_ms, max_connection_pool_size, connection_acquisition_timeout_ms, client_certificate, @@ -885,6 +887,10 @@ impl Request { driver_config = driver_config.with_idle_time_before_connection_test(Duration::from_millis(timeout)); } + if let Some(timeout) = max_connection_lifetime_ms { + driver_config = + driver_config.with_max_connection_lifetime(Duration::from_millis(timeout)); + } if let Some(max_connection_pool_size) = max_connection_pool_size { driver_config = driver_config.with_max_connection_pool_size(max_connection_pool_size); } diff --git a/testkit_backend/src/testkit_backend/responses.rs b/testkit_backend/src/testkit_backend/responses.rs index 6b9622e..fee5248 100644 --- a/testkit_backend/src/testkit_backend/responses.rs +++ b/testkit_backend/src/testkit_backend/responses.rs @@ -37,7 +37,7 @@ use super::{BackendId, TestKitResultT}; // [bolt-version-bump] search tag when changing bolt version support // https://github.com/rust-lang/rust/issues/85077 -const FEATURE_LIST: [&str; 48] = [ +const FEATURE_LIST: [&str; 52] = [ // === FUNCTIONAL FEATURES === "Feature:API:BookmarkManager", "Feature:API:ConnectionAcquisitionTimeout", @@ -45,9 +45,7 @@ const FEATURE_LIST: [&str; 48] = [ // "Feature:API:Driver.ExecuteQuery:WithAuth", "Feature:API:Driver:GetServerInfo", "Feature:API:Driver.IsEncrypted", - // "Feature:API:Driver:MaxConnectionLifetime", - // Even tough the driver does not support notification config, - // TestKit uses this flag to change assertions on the notification objects + "Feature:API:Driver:MaxConnectionLifetime", "Feature:API:Driver:NotificationsConfig", "Feature:API:Driver.VerifyAuthentication", "Feature:API:Driver.VerifyConnectivity", @@ -83,7 +81,7 @@ const FEATURE_LIST: [&str; 48] = [ // "Feature:Bolt:5.5", // unused/deprecated protocol version "Feature:Bolt:5.6", "Feature:Bolt:5.7", - // "Feature:Bolt:5.8", + "Feature:Bolt:5.8", "Feature:Bolt:HandshakeManifestV1", "Feature:Bolt:Patch:UTC", "Feature:Impersonation", @@ -97,8 +95,8 @@ const FEATURE_LIST: [&str; 48] = [ "Optimization:ConnectionReuse", "Optimization:EagerTransactionBegin", "Optimization:ExecuteQueryPipelining", - // "Optimization:HomeDatabaseCache", - // "Optimization:HomeDbCacheBasicPrincipalIsImpersonatedUser", + "Optimization:HomeDatabaseCache", + "Optimization:HomeDbCacheBasicPrincipalIsImpersonatedUser", "Optimization:ImplicitDefaultArguments", "Optimization:MinimalBookmarksSet", "Optimization:MinimalResets", @@ -131,6 +129,14 @@ fn get_plain_skipped_tests() -> &'static HashMap<&'static str, &'static str> { "stub.summary.test_summary.TestSummaryNotifications4x4.test_no_notifications", "An empty list is returned when there are no notifications", ), + ( + "stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_does_not_encompass_router_route_response", + "Pending driver unification: only some drivers consider a single connection acquisition timeout for all operations on acquisition (like fetching routing table) and some consider a separate timeout for each operation", + ), + ( + "stub.driver_parameters.test_connection_acquisition_timeout_ms.TestConnectionAcquisitionTimeoutMs.test_router_handshake_has_own_timeout_in_time", + "Pending driver unification: only some drivers consider a single connection acquisition timeout for all operations on acquisition (like fetching routing table) and some consider a separate timeout for each operation", + ), ( "neo4j.test_summary.TestSummary.test_no_notification_info", "An empty list is returned when there are no notifications", From b4f4ede754741eaf11377cae5882105c9fc7c006 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Sat, 15 Feb 2025 11:40:10 +0100 Subject: [PATCH 2/4] Add unit tests for the cache --- neo4j/src/driver/home_db_cache.rs | 232 +++++++++++++++++++++++++++++- 1 file changed, 226 insertions(+), 6 deletions(-) diff --git a/neo4j/src/driver/home_db_cache.rs b/neo4j/src/driver/home_db_cache.rs index a124180..43889e0 100644 --- a/neo4j/src/driver/home_db_cache.rs +++ b/neo4j/src/driver/home_db_cache.rs @@ -46,7 +46,11 @@ impl Default for HomeDbCache { impl HomeDbCache { pub(super) fn new(max_size: usize) -> Self { let max_size_f64 = max_size as f64; - let prune_size = usize::min(max_size, (max_size_f64 * 0.01).log(max_size_f64) as usize); + let mut prune_size = (0.01 * max_size_f64 * max_size_f64.ln()) as usize; + prune_size = usize::min(prune_size, max_size); + if prune_size == 0 && max_size > 0 { + prune_size = 1; // ensure at least one entry is pruned + } HomeDbCache { cache: Mutex::new(HashMap::with_capacity(max_size)), config: HomeDbCacheConfig { @@ -221,11 +225,8 @@ impl SessionAuthKey { } impl HomeDbCacheKey { - pub(super) fn new( - imp_user: Option<&Arc>, - session_auth: Option<&Arc>, - ) -> Self { - if let Some(user) = imp_user { + pub(super) fn new(user: Option<&Arc>, session_auth: Option<&Arc>) -> Self { + if let Some(user) = user { HomeDbCacheKey::FixedUser(Arc::clone(user)) } else if let Some(auth) = session_auth { if let Some(ValueSend::String(scheme)) = auth.data.get("scheme") { @@ -247,3 +248,222 @@ struct HomeDbCacheEntry { database: Arc, last_used: Instant, } + +#[cfg(test)] +mod test { + use rstest::*; + + use crate::value::time; + use crate::value_map; + + use super::*; + + #[rstest] + #[case(HashMap::new(), HashMap::new())] + #[case( + value_map!({ + "list": [1, 1.5, ValueSend::Null, "string", true], + "principal": "user", + "map": value_map!({ + "nested": value_map!({ + "key": "value", + "when": time::LocalDateTime::new( + time::Date::from_ymd_opt(2021, 1, 1).unwrap(), + time::LocalTime::from_hms_opt(12, 0, 0).unwrap(), + ), + }), + "point": spatial::Cartesian2D::new(1.0, 2.0), + "key": "value", + }), + "nan": ValueSend::Float(f64::NAN), + "foo": "bar", + }), + value_map!({ + "foo": "bar", + "principal": "user", + "nan": ValueSend::Float(f64::NAN), + "list": [1, 1.5, ValueSend::Null, "string", true], + "map": value_map!({ + "key": "value", + "nested": value_map!({ + "key": "value", + "when": time::LocalDateTime::new( + time::Date::from_ymd_opt(2021, 1, 1).unwrap(), + time::LocalTime::from_hms_opt(12, 0, 0).unwrap(), + ), + }), + "point": spatial::Cartesian2D::new(1.0, 2.0), + }), + }) + )] + fn test_cache_key_equality( + #[case] a: HashMap, + #[case] b: HashMap, + ) { + let auth1 = Arc::new(AuthToken { data: a }); + let auth2 = Arc::new(AuthToken { data: b }); + let key1 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth1))); + let key2 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth2))); + #[allow(clippy::eq_op)] // we're explicitly testing the equality implementation here + { + assert_eq!(key1, key1); + assert_eq!(key2, key2); + } + assert_eq!(key1, key2); + assert_eq!(key2, key1); + + let mut hasher1 = std::collections::hash_map::DefaultHasher::new(); + let mut hasher2 = std::collections::hash_map::DefaultHasher::new(); + key1.hash(&mut hasher1); + key2.hash(&mut hasher2); + assert_eq!(hasher1.finish(), hasher2.finish()); + } + + #[rstest] + #[case(value_map!({"principal": "user"}), value_map!({"principal": "admin"}))] + #[case(value_map!({"int": 1}), value_map!({"int": 2}))] + #[case(value_map!({"int": 1}), value_map!({"int": 1.0}))] + #[case(value_map!({"zero": 0.0}), value_map!({"zero": -0.0}))] + #[case(value_map!({"large": f64::INFINITY}), value_map!({"large": f64::NEG_INFINITY}))] + #[case(value_map!({"nan": f64::NAN}), value_map!({"nan": -f64::NAN}))] + #[case(value_map!({"int": 1}), value_map!({"int": "1"}))] + #[case(value_map!({"list": [1, 2]}), value_map!({"list": [2, 1]}))] + fn test_cache_key_inequality( + #[case] a: HashMap, + #[case] b: HashMap, + ) { + let auth1 = Arc::new(AuthToken { data: a }); + let auth2 = Arc::new(AuthToken { data: b }); + let key1 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth1))); + let key2 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth2))); + assert_ne!(key1, key2); + } + + fn fixed_user_key(user: &str) -> HomeDbCacheKey { + HomeDbCacheKey::FixedUser(Arc::new(user.to_string())) + } + + fn auth_basic(principal: &str) -> AuthToken { + AuthToken { + data: value_map!({ + "scheme": "basic", + "principal": principal, + "credentials": "password", + }), + } + } + + fn any_auth_key() -> HomeDbCacheKey { + HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::new(AuthToken { + data: Default::default(), + }))) + } + + #[rstest] + #[case(None, None, HomeDbCacheKey::DriverUser)] + #[case(Some("user"), None, fixed_user_key("user"))] + #[case(Some("user"), Some(auth_basic("user2")), fixed_user_key("user"))] + #[case( + None, + Some(AuthToken::new_basic_auth("user2", "password")), + fixed_user_key("user2") + )] + #[case( + None, + Some(AuthToken::new_basic_auth_with_realm("user2", "password", "my-realm")), + fixed_user_key("user2") + )] + #[case(None, Some(AuthToken::new_basic_auth("", "empty")), fixed_user_key(""))] + #[case(None, Some(AuthToken::new_none_auth()), any_auth_key())] + #[case(None, Some(AuthToken::new_bearer_auth("token123")), any_auth_key())] + #[case(None, Some(AuthToken::new_kerberos_auth("token123")), any_auth_key())] + #[case( + None, + Some(AuthToken::new_custom_auth(None, None, None, None, None)), + any_auth_key() + )] + #[case( + None, + Some(AuthToken::new_custom_auth( + Some("principal".into()), + Some("credentials".into()), + Some("realm".into()), + Some("scheme".into()), + Some(value_map!({"key": "value"})), + )), + any_auth_key() + )] + fn test_cache_key_new( + #[case] user: Option<&str>, + #[case] session_auth: Option, + #[case] expected: HomeDbCacheKey, + ) { + let user = user.map(String::from).map(Arc::new); + let session_auth = session_auth.map(Arc::new); + let expected = match expected { + HomeDbCacheKey::SessionAuth(_) => HomeDbCacheKey::SessionAuth(SessionAuthKey( + Arc::clone(session_auth.as_ref().unwrap()), + )), + _ => expected, + }; + assert_eq!( + HomeDbCacheKey::new(user.as_ref(), session_auth.as_ref()), + expected + ); + } + + #[rstest] + #[case(0, 0)] + #[case(1, 1)] + #[case(5, 1)] + #[case(50, 1)] + #[case(60, 2)] + #[case(100, 4)] + #[case(200, 10)] + #[case(1_000, 69)] + #[case(10_000, 921)] + #[case(100_000, 11_512)] + #[case(1_000_000, 138_155)] + fn test_cache_pruning_size(#[case] max_size: usize, #[case] expected: usize) { + let cache = HomeDbCache::new(max_size); + assert_eq!(cache.config.prune_size, expected); + } + + #[test] + fn test_pruning() { + const SIZE: usize = 200; + const PRUNE_SIZE: usize = 10; + let cache = HomeDbCache::new(SIZE); + // sanity check + assert_eq!(cache.config.prune_size, PRUNE_SIZE); + + let users: Vec<_> = (0..=SIZE).map(|i| Arc::new(format!("user{i}"))).collect(); + let keys: Vec<_> = (0..=SIZE) + .map(|i| HomeDbCacheKey::new(Some(&users[i]), None)) + .collect(); + let entries: Vec<_> = (0..=SIZE).map(|i| Arc::new(format!("db{i}"))).collect(); + + // WHEN: cache is filled to the max + for i in 0..SIZE { + cache.update(keys[i].clone(), Arc::clone(&entries[i])); + } + // THEN: no entry has been removed + for i in 0..SIZE { + assert_eq!(cache.get(&keys[i]), Some(Arc::clone(&entries[i]))); + } + + // WHEN: The oldest entry is touched + cache.get(&keys[0]); + // AND: cache is filled with one more entry + cache.update(keys[SIZE].clone(), Arc::clone(&entries[SIZE])); + // THEN: the oldest PRUNE_SIZE entries (2nd to (PRUNE_SIZE + 1)th) are pruned + for key in keys.iter().skip(1).take(PRUNE_SIZE) { + assert_eq!(cache.get(key), None); + } + // AND: the rest of the entries are still in the cache + assert_eq!(cache.get(&keys[0]), Some(Arc::clone(&entries[0]))); + for i in PRUNE_SIZE + 2..=SIZE { + assert_eq!(cache.get(&keys[i]), Some(Arc::clone(&entries[i]))); + } + } +} From 99f20005368dff2c17497f6567e2d891a614e652 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Sat, 15 Feb 2025 12:01:51 +0100 Subject: [PATCH 3/4] Run cargo check on tests and MSRV + stable --- .pre-commit-config.yaml | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1fb0809..e24e109 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,25 @@ repos: hooks: - id: check name: cargo check - entry: cargo check -- + entry: cargo +stable check --all --tests -- + language: system + types: [rust] + pass_filenames: false + - id: check-msrv-all-features + name: cargo check (all features) + entry: cargo +stable check --all --tests --all-features -- + language: system + types: [rust] + pass_filenames: false + - id: check-msrv + name: cargo check MSRV + entry: cargo +1.70 check --all --tests -- + language: system + types: [rust] + pass_filenames: false + - id: check-msrv-all-features + name: cargo check MSRV (all features) + entry: cargo +1.70 check --all --tests --all-features -- language: system types: [rust] pass_filenames: false From 61eaf6432bec5dd9af87a9a818938b44b6dac5a4 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Sat, 15 Feb 2025 16:21:23 +0100 Subject: [PATCH 4/4] Refactoring --- neo4j/src/driver/home_db_cache.rs | 26 ++++++++++--------- neo4j/src/driver/io/bolt.rs | 16 ++++++------ neo4j/src/driver/io/bolt/bolt5x0/protocol.rs | 7 +---- neo4j/src/driver/io/bolt/bolt5x2/protocol.rs | 7 +---- .../src/driver/io/bolt/message_parameters.rs | 4 +-- neo4j/src/driver/io/pool.rs | 18 +++++-------- neo4j/src/driver/session.rs | 2 +- 7 files changed, 33 insertions(+), 47 deletions(-) diff --git a/neo4j/src/driver/home_db_cache.rs b/neo4j/src/driver/home_db_cache.rs index 43889e0..01cfc89 100644 --- a/neo4j/src/driver/home_db_cache.rs +++ b/neo4j/src/driver/home_db_cache.rs @@ -226,19 +226,21 @@ impl SessionAuthKey { impl HomeDbCacheKey { pub(super) fn new(user: Option<&Arc>, session_auth: Option<&Arc>) -> Self { - if let Some(user) = user { - HomeDbCacheKey::FixedUser(Arc::clone(user)) - } else if let Some(auth) = session_auth { - if let Some(ValueSend::String(scheme)) = auth.data.get("scheme") { - if scheme == "basic" { - if let Some(ValueSend::String(user)) = auth.data.get("principal") { - return HomeDbCacheKey::FixedUser(Arc::new(user.clone())); - } - } + fn get_basic_auth_principal(auth: &AuthToken) -> Option<&str> { + let scheme = auth.data.get("scheme")?.as_string()?.as_str(); + if scheme != "basic" { + return None; } - HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(auth))) - } else { - HomeDbCacheKey::DriverUser + Some(auth.data.get("principal")?.as_string()?.as_str()) + } + + match (user, session_auth) { + (Some(user), _) => HomeDbCacheKey::FixedUser(Arc::clone(user)), + (None, Some(auth)) => match get_basic_auth_principal(auth) { + Some(user) => HomeDbCacheKey::FixedUser(Arc::new(user.to_string())), + None => HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(auth))), + }, + (None, None) => HomeDbCacheKey::DriverUser, } } } diff --git a/neo4j/src/driver/io/bolt.rs b/neo4j/src/driver/io/bolt.rs index 086576c..f3aca74 100644 --- a/neo4j/src/driver/io/bolt.rs +++ b/neo4j/src/driver/io/bolt.rs @@ -523,15 +523,15 @@ trait BoltProtocol: Debug { #[enum_dispatch(BoltProtocol)] #[derive(Debug)] enum BoltProtocolVersion { - V5x8(Bolt5x8), - V5x7(Bolt5x7), - V5x6(Bolt5x6), - V5x4(Bolt5x4), - V5x3(Bolt5x3), - V5x2(Bolt5x2), - V5x1(Bolt5x1), - V5x0(Bolt5x0), V4x4(Bolt4x4), + V5x0(Bolt5x0), + V5x1(Bolt5x1), + V5x2(Bolt5x2), + V5x3(Bolt5x3), + V5x4(Bolt5x4), + V5x6(Bolt5x6), + V5x7(Bolt5x7), + V5x8(Bolt5x8), } #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] diff --git a/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs index 5293a1d..42c8997 100644 --- a/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x0/protocol.rs @@ -709,12 +709,7 @@ impl BoltProtocol for Bolt5x0 { Self::write_mode_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, mode)?; - Self::write_db_entry( - log_buf.as_mut(), - &mut serializer, - &mut dbg_serializer, - db.as_deref().map(String::as_str), - )?; + Self::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; Self::write_imp_user_entry( log_buf.as_mut(), diff --git a/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs b/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs index 3f18453..c9ebdc1 100644 --- a/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs +++ b/neo4j/src/driver/io/bolt/bolt5x2/protocol.rs @@ -300,12 +300,7 @@ impl BoltProtocol for Bolt5x2 { mode, )?; - Bolt5x0::::write_db_entry( - log_buf.as_mut(), - &mut serializer, - &mut dbg_serializer, - db.as_deref().map(String::as_str), - )?; + Bolt5x0::::write_db_entry(log_buf.as_mut(), &mut serializer, &mut dbg_serializer, db)?; Bolt5x0::::write_imp_user_entry( log_buf.as_mut(), diff --git a/neo4j/src/driver/io/bolt/message_parameters.rs b/neo4j/src/driver/io/bolt/message_parameters.rs index 09f6656..1cc7380 100644 --- a/neo4j/src/driver/io/bolt/message_parameters.rs +++ b/neo4j/src/driver/io/bolt/message_parameters.rs @@ -84,7 +84,7 @@ pub(crate) struct RunParameters<'a, KP: Borrow + Debug, KM: Borrow + D pub(super) tx_timeout: Option, pub(super) tx_metadata: Option<&'a HashMap>, pub(super) mode: Option<&'a str>, - pub(super) db: Option>, + pub(super) db: Option<&'a str>, pub(super) imp_user: Option<&'a str>, pub(super) notification_filter: Option<&'a NotificationFilter>, } @@ -98,7 +98,7 @@ impl<'a, KP: Borrow + Debug, KM: Borrow + Debug> RunParameters<'a, KP, tx_timeout: Option, tx_metadata: Option<&'a HashMap>, mode: Option<&'a str>, - db: Option>, + db: Option<&'a str>, imp_user: Option<&'a str>, notification_filter: &'a NotificationFilter, ) -> Self { diff --git a/neo4j/src/driver/io/pool.rs b/neo4j/src/driver/io/pool.rs index cc59fef..c96a4e0 100644 --- a/neo4j/src/driver/io/pool.rs +++ b/neo4j/src/driver/io/pool.rs @@ -503,7 +503,7 @@ impl RoutingPool { ) -> Result<(RwLockReadGuard, Option>)> { let rt_args = args.update_rt_args; let db_key = rt_args.rt_key(); - let db_name = RefCell::new(rt_args.db_name()); + let db_name = RefCell::new(rt_args.db_request()); let db_name_ref = &db_name; let lock = self.routing_tables.maybe_write( |rts| { @@ -589,19 +589,13 @@ impl RoutingPool { Some(args_db) if !args_db.guess => { let db = Some(Arc::clone(&args_db.db)); new_rt.database.clone_from(&db); - debug!("Storing new routing table for {:?}: {:?}", db, new_rt); - rts.insert(db.as_ref().map(Arc::clone), new_rt); - self.clean_up_pools(rts); - db - } - _ => { - let db = new_rt.database.clone(); - debug!("Storing new routing table for {:?}: {:?}", db, new_rt); - rts.insert(db.clone(), new_rt); - self.clean_up_pools(rts); db } + _ => new_rt.database.clone(), }; + debug!("Storing new routing table for {:?}: {:?}", db, new_rt); + rts.insert(db.as_ref().map(Arc::clone), new_rt); + self.clean_up_pools(rts); if let Some(cb) = args.db_resolution_cb { cb(db.as_ref().map(Arc::clone)); } @@ -916,7 +910,7 @@ impl UpdateRtArgs<'_> { }) } - fn db_name(&self) -> Option> { + fn db_request(&self) -> Option> { self.db.as_ref().and_then(|db| match db.guess { true => None, false => Some(Arc::clone(&db.db)), diff --git a/neo4j/src/driver/session.rs b/neo4j/src/driver/session.rs index a2d5c45..8f90907 100644 --- a/neo4j/src/driver/session.rs +++ b/neo4j/src/driver/session.rs @@ -169,7 +169,7 @@ impl<'driver> Session<'driver> { builder.timeout.raw(), Some(builder.meta.borrow()), builder.mode.as_protocol_str(), - target_db, + target_db.as_deref().map(String::as_str), self.config .config .as_ref()