diff --git a/lib/src/config.rs b/lib/src/config.rs index f154f4b..549dba3 100644 --- a/lib/src/config.rs +++ b/lib/src/config.rs @@ -3,7 +3,6 @@ use crate::errors::{Error, Result}; use std::path::Path; use std::{ops::Deref, sync::Arc}; -const DEFAULT_DATABASE: &str = "neo4j"; const DEFAULT_FETCH_SIZE: usize = 200; const DEFAULT_MAX_CONNECTIONS: usize = 16; @@ -24,12 +23,6 @@ impl From for Database { } } -impl Default for Database { - fn default() -> Self { - Database(DEFAULT_DATABASE.into()) - } -} - impl AsRef for Database { fn as_ref(&self) -> &str { &self.0 @@ -47,7 +40,7 @@ impl Deref for Database { /// The configuration that is used once a connection is alive. #[derive(Debug, Clone)] pub struct LiveConfig { - pub(crate) db: Database, + pub(crate) db: Option, pub(crate) fetch_size: usize, } @@ -58,7 +51,7 @@ pub struct Config { pub(crate) user: String, pub(crate) password: String, pub(crate) max_connections: usize, - pub(crate) db: Database, + pub(crate) db: Option, pub(crate) fetch_size: usize, pub(crate) client_certificate: Option, } @@ -77,7 +70,7 @@ pub struct ConfigBuilder { uri: Option, user: Option, password: Option, - db: Database, + db: Option, fetch_size: usize, max_connections: usize, client_certificate: Option, @@ -109,9 +102,11 @@ impl ConfigBuilder { /// The name of the database to connect to. /// - /// Defaults to "neo4j" if not set. + /// Defaults to the server configured default database if not set. + /// The database can also be specified on a per-query level, which will + /// override this value. pub fn db(mut self, db: impl Into) -> Self { - self.db = db.into(); + self.db = Some(db.into()); self } @@ -160,7 +155,7 @@ impl Default for ConfigBuilder { uri: None, user: None, password: None, - db: DEFAULT_DATABASE.into(), + db: None, max_connections: DEFAULT_MAX_CONNECTIONS, fetch_size: DEFAULT_FETCH_SIZE, client_certificate: None, @@ -186,7 +181,7 @@ mod tests { assert_eq!(config.uri, "127.0.0.1:7687"); assert_eq!(config.user, "some_user"); assert_eq!(config.password, "some_password"); - assert_eq!(&*config.db, "some_db"); + assert_eq!(config.db.as_deref(), Some("some_db")); assert_eq!(config.fetch_size, 10); assert_eq!(config.max_connections, 5); assert!(config.client_certificate.is_none()); @@ -203,7 +198,7 @@ mod tests { assert_eq!(config.uri, "127.0.0.1:7687"); assert_eq!(config.user, "some_user"); assert_eq!(config.password, "some_password"); - assert_eq!(&*config.db, "neo4j"); + assert_eq!(config.db, None); assert_eq!(config.fetch_size, 200); assert_eq!(config.max_connections, 16); assert!(config.client_certificate.is_none()); diff --git a/lib/src/graph.rs b/lib/src/graph.rs index 4b92c45..5300e31 100644 --- a/lib/src/graph.rs +++ b/lib/src/graph.rs @@ -53,7 +53,7 @@ impl Graph { /// /// Transactions will not be automatically retried on any failure. pub async fn start_txn(&self) -> Result { - self.start_txn_on(self.config.db.clone()).await + self.impl_start_txn_on(self.config.db.clone()).await } /// Starts a new transaction on the provided database. @@ -62,8 +62,12 @@ impl Graph { /// /// Transactions will not be automatically retried on any failure. pub async fn start_txn_on(&self, db: impl Into) -> Result { + self.impl_start_txn_on(Some(db.into())).await + } + + async fn impl_start_txn_on(&self, db: Option) -> Result { let connection = self.pool.get().await?; - Txn::new(db.into(), self.config.fetch_size, connection).await + Txn::new(db, self.config.fetch_size, connection).await } /// Runs a query on the configured database using a connection from the connection pool, @@ -78,7 +82,7 @@ impl Graph { /// /// use [`Graph::execute`] when you are interested in the result stream pub async fn run(&self, q: Query) -> Result<()> { - self.run_on(&self.config.db, q).await + self.impl_run_on(self.config.db.clone(), q).await } /// Runs a query on the provided database using a connection from the connection pool. @@ -92,12 +96,17 @@ impl Graph { /// Use [`Graph::run`] for cases where you just want a write operation /// /// use [`Graph::execute`] when you are interested in the result stream - pub async fn run_on(&self, db: &str, q: Query) -> Result<()> { + pub async fn run_on(&self, db: impl Into, q: Query) -> Result<()> { + self.impl_run_on(Some(db.into()), q).await + } + + async fn impl_run_on(&self, db: Option, q: Query) -> Result<()> { backoff::future::retry_notify( self.pool.manager().backoff(), || { let pool = &self.pool; let query = &q; + let db = db.as_deref(); async move { let mut connection = pool.get().await.map_err(crate::Error::from)?; query.run_retryable(db, &mut connection).await @@ -115,7 +124,7 @@ impl Graph { /// This includes errors during a leader election or when the transaction resources on the server (memory, handles, ...) are exhausted. /// Retries happen with an exponential backoff until a retry delay exceeds 60s, at which point the query fails with the last error as it would without any retry. pub async fn execute(&self, q: Query) -> Result { - self.execute_on(&self.config.db, q).await + self.impl_execute_on(self.config.db.clone(), q).await } /// Executes a query on the provided database and returns a [`DetaRowStream`] @@ -124,13 +133,18 @@ impl Graph { /// All errors with the `Transient` error class as well as a few other error classes are considered retryable. /// This includes errors during a leader election or when the transaction resources on the server (memory, handles, ...) are exhausted. /// Retries happen with an exponential backoff until a retry delay exceeds 60s, at which point the query fails with the last error as it would without any retry. - pub async fn execute_on(&self, db: &str, q: Query) -> Result { + pub async fn execute_on(&self, db: impl Into, q: Query) -> Result { + self.impl_execute_on(Some(db.into()), q).await + } + + async fn impl_execute_on(&self, db: Option, q: Query) -> Result { backoff::future::retry_notify( self.pool.manager().backoff(), || { let pool = &self.pool; let fetch_size = self.config.fetch_size; let query = &q; + let db = db.as_deref(); async move { let connection = pool.get().await.map_err(crate::Error::from)?; query.execute_retryable(db, fetch_size, connection).await diff --git a/lib/src/messages.rs b/lib/src/messages.rs index af469da..4343674 100644 --- a/lib/src/messages.rs +++ b/lib/src/messages.rs @@ -132,8 +132,8 @@ impl BoltRequest { BoltRequest::Hello(hello::Hello::new(data)) } - pub fn run(db: &str, query: &str, params: BoltMap) -> BoltRequest { - BoltRequest::Run(Run::new(db.into(), query.into(), params)) + pub fn run(db: Option<&str>, query: &str, params: BoltMap) -> BoltRequest { + BoltRequest::Run(Run::new(db.map(Into::into), query.into(), params)) } #[cfg_attr( @@ -152,8 +152,9 @@ impl BoltRequest { BoltRequest::Discard(discard::Discard::default()) } - pub fn begin(db: &str) -> BoltRequest { - let begin = Begin::new([("db".into(), db.into())].into_iter().collect()); + pub fn begin(db: Option<&str>) -> BoltRequest { + let extra = db.into_iter().map(|db| ("db".into(), db.into())).collect(); + let begin = Begin::new(extra); BoltRequest::Begin(begin) } diff --git a/lib/src/messages/run.rs b/lib/src/messages/run.rs index a6c5fd2..2515235 100644 --- a/lib/src/messages/run.rs +++ b/lib/src/messages/run.rs @@ -10,12 +10,13 @@ pub struct Run { } impl Run { - pub fn new(db: BoltString, query: BoltString, parameters: BoltMap) -> Run { + pub fn new(db: Option, query: BoltString, parameters: BoltMap) -> Run { Run { query, parameters, - extra: vec![("db".into(), BoltType::String(db))] + extra: db .into_iter() + .map(|db| ("db".into(), BoltType::String(db))) .collect(), } } @@ -31,7 +32,7 @@ mod tests { #[test] fn should_serialize_run() { let run = Run::new( - "test".into(), + Some("test".into()), "query".into(), vec![("k".into(), "v".into())].into_iter().collect(), ); @@ -69,7 +70,7 @@ mod tests { #[test] fn should_serialize_run_with_no_params() { - let run = Run::new("".into(), "query".into(), BoltMap::default()); + let run = Run::new(None, "query".into(), BoltMap::default()); let bytes: Bytes = run.into_bytes(Version::V4_1).unwrap(); @@ -85,11 +86,7 @@ mod tests { b'r', b'y', map::TINY, - map::TINY | 1, - string::TINY | 2, - b'd', - b'b', - string::TINY, + map::TINY, ]) ); } diff --git a/lib/src/query.rs b/lib/src/query.rs index 42a3ab0..f878338 100644 --- a/lib/src/query.rs +++ b/lib/src/query.rs @@ -45,7 +45,11 @@ impl Query { self.params.value.contains_key(key) } - pub(crate) async fn run(self, db: &str, connection: &mut ManagedConnection) -> Result<()> { + pub(crate) async fn run( + self, + db: Option<&str>, + connection: &mut ManagedConnection, + ) -> Result<()> { let request = BoltRequest::run(db, &self.query, self.params); Self::try_run(request, connection) .await @@ -54,7 +58,7 @@ impl Query { pub(crate) async fn run_retryable( &self, - db: &str, + db: Option<&str>, connection: &mut ManagedConnection, ) -> QueryResult<()> { let request = BoltRequest::run(db, &self.query, self.params.clone()); @@ -63,7 +67,7 @@ impl Query { pub(crate) async fn execute_retryable( &self, - db: &str, + db: Option<&str>, fetch_size: usize, mut connection: ManagedConnection, ) -> QueryResult { @@ -75,7 +79,7 @@ impl Query { pub(crate) async fn execute_mut<'conn>( self, - db: &str, + db: Option<&str>, fetch_size: usize, connection: &'conn mut ManagedConnection, ) -> Result { diff --git a/lib/src/txn.rs b/lib/src/txn.rs index 9252c61..f82c957 100644 --- a/lib/src/txn.rs +++ b/lib/src/txn.rs @@ -14,18 +14,18 @@ use crate::{ /// When a transation is started, a dedicated connection is resered and moved into the handle which /// will be released to the connection pool when the [`Txn`] handle is dropped. pub struct Txn { - db: Database, + db: Option, fetch_size: usize, connection: ManagedConnection, } impl Txn { pub(crate) async fn new( - db: Database, + db: Option, fetch_size: usize, mut connection: ManagedConnection, ) -> Result { - let begin = BoltRequest::begin(&db); + let begin = BoltRequest::begin(db.as_deref()); match connection.send_recv(begin).await? { BoltResponse::Success(_) => Ok(Txn { db, @@ -49,12 +49,12 @@ impl Txn { /// Runs a single query and discards the stream. pub async fn run(&mut self, q: Query) -> Result<()> { - q.run(&self.db, &mut self.connection).await + q.run(self.db.as_deref(), &mut self.connection).await } /// Executes a query and returns a [`RowStream`] pub async fn execute(&mut self, q: Query) -> Result { - q.execute_mut(&self.db, self.fetch_size, &mut self.connection) + q.execute_mut(self.db.as_deref(), self.fetch_size, &mut self.connection) .await } diff --git a/lib/tests/container.rs b/lib/tests/container.rs index 1f42030..e021288 100644 --- a/lib/tests/container.rs +++ b/lib/tests/container.rs @@ -1,6 +1,6 @@ use lenient_semver::Version; use neo4rs::{ConfigBuilder, Graph}; -use testcontainers::{runners::SyncRunner, Container, ImageExt}; +use testcontainers::{runners::AsyncRunner, ContainerAsync, ContainerRequest, ImageExt}; use testcontainers_modules::neo4j::{Neo4j, Neo4jImage}; use std::{error::Error, io::BufRead as _}; @@ -10,6 +10,7 @@ use std::{error::Error, io::BufRead as _}; pub struct Neo4jContainerBuilder { enterprise: bool, config: ConfigBuilder, + env: Vec<(String, String)>, } #[allow(dead_code)] @@ -23,25 +24,39 @@ impl Neo4jContainerBuilder { self } - pub fn with_config(mut self, config: ConfigBuilder) -> Self { + pub fn with_driver_config(mut self, config: ConfigBuilder) -> Self { self.config = config; self } - pub fn modify_config(mut self, block: impl FnOnce(ConfigBuilder) -> ConfigBuilder) -> Self { + pub fn modify_driver_config( + mut self, + block: impl FnOnce(ConfigBuilder) -> ConfigBuilder, + ) -> Self { self.config = block(self.config); self } + pub fn add_env(mut self, key: impl Into, value: impl Into) -> Self { + self.env.push((key.into(), value.into())); + self + } + + pub fn with_server_config(self, key: &str, value: impl Into) -> Self { + let key = format!("NEO4J_{}", key.replace('_', "__").replace('.', "_")); + self.add_env(key, value) + } + pub async fn start(self) -> Result> { - Neo4jContainer::from_config_and_edition(self.config, self.enterprise).await + Neo4jContainer::from_config_and_edition_and_env(self.config, self.enterprise, self.env) + .await } } pub struct Neo4jContainer { graph: Graph, version: String, - _container: Option>, + _container: Option>, } impl Neo4jContainer { @@ -58,6 +73,20 @@ impl Neo4jContainer { config: ConfigBuilder, enterprise_edition: bool, ) -> Result> { + Self::from_config_and_edition_and_env::<_, String, String>(config, enterprise_edition, []) + .await + } + + pub async fn from_config_and_edition_and_env( + config: ConfigBuilder, + enterprise_edition: bool, + env_vars: I, + ) -> Result> + where + I: IntoIterator, + K: Into, + V: Into, + { let _ = pretty_env_logger::try_init(); let server = Self::server_from_env(); @@ -65,7 +94,12 @@ impl Neo4jContainer { let (uri, _container) = match server { TestServer::TestContainer => { - let (uri, container) = Self::create_testcontainer(&connection, enterprise_edition)?; + let (uri, container) = Self::create_testcontainer( + &connection, + enterprise_edition, + env_vars.into_iter().map(|(k, v)| (k.into(), v.into())), + ) + .await?; (uri, Some(container)) } TestServer::External(uri) | TestServer::Aura(uri) => (uri, None), @@ -107,10 +141,30 @@ impl Neo4jContainer { .unwrap_or(TestServer::TestContainer) } - fn create_testcontainer( + async fn create_testcontainer( connection: &TestConnection, enterprise: bool, - ) -> Result<(String, Container), Box> { + env_vars: I, + ) -> Result<(String, ContainerAsync), Box> + where + I: Iterator, + { + let container = Self::create_testcontainer_image(connection, enterprise, env_vars)?; + let container = container.start().await?; + + let uri = format!("bolt://127.0.0.1:{}", container.image().bolt_port_ipv4()?); + + Ok((uri, container)) + } + + fn create_testcontainer_image( + connection: &TestConnection, + enterprise: bool, + env_vars: I, + ) -> Result, Box> + where + I: Iterator, + { let image = Neo4j::new() .with_user(connection.auth.user.to_owned()) .with_password(connection.auth.pass.to_owned()); @@ -147,17 +201,17 @@ impl Neo4jContainer { .into()); } - image - .with_version(version) - .with_env_var("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes") - .start()? + env_vars.fold( + image + .with_version(version) + .with_env_var("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes"), + |i, (k, v)| i.with_env_var(k, v), + ) } else { - image.with_version(connection.version.to_owned()).start()? + image.with_version(connection.version.to_owned()).into() }; - let uri = format!("bolt://127.0.0.1:{}", container.image().bolt_port_ipv4()?); - - Ok((uri, container)) + Ok(container) } fn create_test_endpoint(use_aura: bool) -> TestConnection { diff --git a/lib/tests/txn_change_db.rs b/lib/tests/txn_change_db.rs index 9457bab..93ed8bd 100644 --- a/lib/tests/txn_change_db.rs +++ b/lib/tests/txn_change_db.rs @@ -7,15 +7,19 @@ mod container; #[tokio::test] async fn txn_changes_db() { let neo4j = match container::Neo4jContainerBuilder::new() - .modify_config(|c| c.db("deebee")) + .modify_driver_config(|c| c.db("deebee")) .with_enterprise_edition() .start() .await { Ok(n) => n, Err(e) => { - eprintln!("Skipping test: {}", e); - return; + if e.to_string().contains("Neo4j Enterprise Edition") { + eprintln!("Skipping test: {}", e); + return; + } + + std::panic::panic_any(e); } }; let graph = neo4j.graph(); diff --git a/lib/tests/use_default_db.rs b/lib/tests/use_default_db.rs new file mode 100644 index 0000000..501efd6 --- /dev/null +++ b/lib/tests/use_default_db.rs @@ -0,0 +1,69 @@ +use futures::TryStreamExt; +use neo4rs::*; + +mod container; + +#[tokio::test] +async fn use_default_db() { + let dbname = uuid::Uuid::new_v4().to_string().replace(['-', '_'], ""); + + let neo4j = match container::Neo4jContainerBuilder::new() + .with_server_config("initial.dbms.default_database", dbname.as_str()) + .with_enterprise_edition() + .start() + .await + { + Ok(n) => n, + Err(e) => { + if e.to_string().contains("Neo4j Enterprise Edition") { + eprintln!("Skipping test: {}", e); + return; + } + + std::panic::panic_any(e); + } + }; + let graph = neo4j.graph(); + + let default_db = graph + .execute_on("system", query("SHOW DEFAULT DATABASE")) + .await + .unwrap() + .column_into_stream::("name") + .try_fold(None::, |acc, db| async { Ok(acc.or(Some(db))) }) + .await + .unwrap() + .unwrap(); + + if default_db != dbname { + eprintln!( + concat!( + "Skipping test: The test must run against a testcontainer ", + "or have `{}` configured as the default database" + ), + dbname + ); + return; + } + + let id = uuid::Uuid::new_v4(); + graph + .run(query("CREATE (:Node { uuid: $uuid })").param("uuid", id.to_string())) + .await + .unwrap(); + + let count = graph + .execute_on( + dbname.as_str(), + query("MATCH (n:Node {uuid: $uuid}) RETURN count(n) AS result") + .param("uuid", id.to_string()), + ) + .await + .unwrap() + .column_into_stream::("result") + .try_fold(0, |sum, count| async move { Ok(sum + count) }) + .await + .unwrap(); + + assert_eq!(count, 1); +}