Skip to content

Commit

Permalink
Use server default db over hardcoded default db (#197)
Browse files Browse the repository at this point in the history
* Add option to pass env and server config to testcontainers

* Run testcontainers async as all tests are async as well

* Only skip EE test when the license was not accepted

* Add test to show that the server default db is used

* Use server default db over hardcoded default db
  • Loading branch information
knutwalker authored Aug 28, 2024
1 parent fe990fb commit fc69715
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 62 deletions.
25 changes: 10 additions & 15 deletions lib/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::errors::{Error, Result};
use std::path::Path;
use std::{ops::Deref, sync::Arc};

const DEFAULT_DATABASE: &str = "neo4j";
const DEFAULT_FETCH_SIZE: usize = 200;
const DEFAULT_MAX_CONNECTIONS: usize = 16;

Expand All @@ -24,12 +23,6 @@ impl From<String> for Database {
}
}

impl Default for Database {
fn default() -> Self {
Database(DEFAULT_DATABASE.into())
}
}

impl AsRef<str> for Database {
fn as_ref(&self) -> &str {
&self.0
Expand All @@ -47,7 +40,7 @@ impl Deref for Database {
/// The configuration that is used once a connection is alive.
#[derive(Debug, Clone)]
pub struct LiveConfig {
pub(crate) db: Database,
pub(crate) db: Option<Database>,
pub(crate) fetch_size: usize,
}

Expand All @@ -58,7 +51,7 @@ pub struct Config {
pub(crate) user: String,
pub(crate) password: String,
pub(crate) max_connections: usize,
pub(crate) db: Database,
pub(crate) db: Option<Database>,
pub(crate) fetch_size: usize,
pub(crate) client_certificate: Option<ClientCertificate>,
}
Expand All @@ -77,7 +70,7 @@ pub struct ConfigBuilder {
uri: Option<String>,
user: Option<String>,
password: Option<String>,
db: Database,
db: Option<Database>,
fetch_size: usize,
max_connections: usize,
client_certificate: Option<ClientCertificate>,
Expand Down Expand Up @@ -109,9 +102,11 @@ impl ConfigBuilder {

/// The name of the database to connect to.
///
/// Defaults to "neo4j" if not set.
/// Defaults to the server configured default database if not set.
/// The database can also be specified on a per-query level, which will
/// override this value.
pub fn db(mut self, db: impl Into<Database>) -> Self {
self.db = db.into();
self.db = Some(db.into());
self
}

Expand Down Expand Up @@ -160,7 +155,7 @@ impl Default for ConfigBuilder {
uri: None,
user: None,
password: None,
db: DEFAULT_DATABASE.into(),
db: None,
max_connections: DEFAULT_MAX_CONNECTIONS,
fetch_size: DEFAULT_FETCH_SIZE,
client_certificate: None,
Expand All @@ -186,7 +181,7 @@ mod tests {
assert_eq!(config.uri, "127.0.0.1:7687");
assert_eq!(config.user, "some_user");
assert_eq!(config.password, "some_password");
assert_eq!(&*config.db, "some_db");
assert_eq!(config.db.as_deref(), Some("some_db"));
assert_eq!(config.fetch_size, 10);
assert_eq!(config.max_connections, 5);
assert!(config.client_certificate.is_none());
Expand All @@ -203,7 +198,7 @@ mod tests {
assert_eq!(config.uri, "127.0.0.1:7687");
assert_eq!(config.user, "some_user");
assert_eq!(config.password, "some_password");
assert_eq!(&*config.db, "neo4j");
assert_eq!(config.db, None);
assert_eq!(config.fetch_size, 200);
assert_eq!(config.max_connections, 16);
assert!(config.client_certificate.is_none());
Expand Down
26 changes: 20 additions & 6 deletions lib/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Graph {
///
/// Transactions will not be automatically retried on any failure.
pub async fn start_txn(&self) -> Result<Txn> {
self.start_txn_on(self.config.db.clone()).await
self.impl_start_txn_on(self.config.db.clone()).await
}

/// Starts a new transaction on the provided database.
Expand All @@ -62,8 +62,12 @@ impl Graph {
///
/// Transactions will not be automatically retried on any failure.
pub async fn start_txn_on(&self, db: impl Into<Database>) -> Result<Txn> {
self.impl_start_txn_on(Some(db.into())).await
}

async fn impl_start_txn_on(&self, db: Option<Database>) -> Result<Txn> {
let connection = self.pool.get().await?;
Txn::new(db.into(), self.config.fetch_size, connection).await
Txn::new(db, self.config.fetch_size, connection).await
}

/// Runs a query on the configured database using a connection from the connection pool,
Expand All @@ -78,7 +82,7 @@ impl Graph {
///
/// use [`Graph::execute`] when you are interested in the result stream
pub async fn run(&self, q: Query) -> Result<()> {
self.run_on(&self.config.db, q).await
self.impl_run_on(self.config.db.clone(), q).await
}

/// Runs a query on the provided database using a connection from the connection pool.
Expand All @@ -92,12 +96,17 @@ impl Graph {
/// Use [`Graph::run`] for cases where you just want a write operation
///
/// use [`Graph::execute`] when you are interested in the result stream
pub async fn run_on(&self, db: &str, q: Query) -> Result<()> {
pub async fn run_on(&self, db: impl Into<Database>, q: Query) -> Result<()> {
self.impl_run_on(Some(db.into()), q).await
}

async fn impl_run_on(&self, db: Option<Database>, q: Query) -> Result<()> {
backoff::future::retry_notify(
self.pool.manager().backoff(),
|| {
let pool = &self.pool;
let query = &q;
let db = db.as_deref();
async move {
let mut connection = pool.get().await.map_err(crate::Error::from)?;
query.run_retryable(db, &mut connection).await
Expand All @@ -115,7 +124,7 @@ impl Graph {
/// This includes errors during a leader election or when the transaction resources on the server (memory, handles, ...) are exhausted.
/// Retries happen with an exponential backoff until a retry delay exceeds 60s, at which point the query fails with the last error as it would without any retry.
pub async fn execute(&self, q: Query) -> Result<DetachedRowStream> {
self.execute_on(&self.config.db, q).await
self.impl_execute_on(self.config.db.clone(), q).await
}

/// Executes a query on the provided database and returns a [`DetaRowStream`]
Expand All @@ -124,13 +133,18 @@ impl Graph {
/// All errors with the `Transient` error class as well as a few other error classes are considered retryable.
/// This includes errors during a leader election or when the transaction resources on the server (memory, handles, ...) are exhausted.
/// Retries happen with an exponential backoff until a retry delay exceeds 60s, at which point the query fails with the last error as it would without any retry.
pub async fn execute_on(&self, db: &str, q: Query) -> Result<DetachedRowStream> {
pub async fn execute_on(&self, db: impl Into<Database>, q: Query) -> Result<DetachedRowStream> {
self.impl_execute_on(Some(db.into()), q).await
}

async fn impl_execute_on(&self, db: Option<Database>, q: Query) -> Result<DetachedRowStream> {
backoff::future::retry_notify(
self.pool.manager().backoff(),
|| {
let pool = &self.pool;
let fetch_size = self.config.fetch_size;
let query = &q;
let db = db.as_deref();
async move {
let connection = pool.get().await.map_err(crate::Error::from)?;
query.execute_retryable(db, fetch_size, connection).await
Expand Down
9 changes: 5 additions & 4 deletions lib/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ impl BoltRequest {
BoltRequest::Hello(hello::Hello::new(data))
}

pub fn run(db: &str, query: &str, params: BoltMap) -> BoltRequest {
BoltRequest::Run(Run::new(db.into(), query.into(), params))
pub fn run(db: Option<&str>, query: &str, params: BoltMap) -> BoltRequest {
BoltRequest::Run(Run::new(db.map(Into::into), query.into(), params))
}

#[cfg_attr(
Expand All @@ -152,8 +152,9 @@ impl BoltRequest {
BoltRequest::Discard(discard::Discard::default())
}

pub fn begin(db: &str) -> BoltRequest {
let begin = Begin::new([("db".into(), db.into())].into_iter().collect());
pub fn begin(db: Option<&str>) -> BoltRequest {
let extra = db.into_iter().map(|db| ("db".into(), db.into())).collect();
let begin = Begin::new(extra);
BoltRequest::Begin(begin)
}

Expand Down
15 changes: 6 additions & 9 deletions lib/src/messages/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ pub struct Run {
}

impl Run {
pub fn new(db: BoltString, query: BoltString, parameters: BoltMap) -> Run {
pub fn new(db: Option<BoltString>, query: BoltString, parameters: BoltMap) -> Run {
Run {
query,
parameters,
extra: vec![("db".into(), BoltType::String(db))]
extra: db
.into_iter()
.map(|db| ("db".into(), BoltType::String(db)))
.collect(),
}
}
Expand All @@ -31,7 +32,7 @@ mod tests {
#[test]
fn should_serialize_run() {
let run = Run::new(
"test".into(),
Some("test".into()),
"query".into(),
vec![("k".into(), "v".into())].into_iter().collect(),
);
Expand Down Expand Up @@ -69,7 +70,7 @@ mod tests {

#[test]
fn should_serialize_run_with_no_params() {
let run = Run::new("".into(), "query".into(), BoltMap::default());
let run = Run::new(None, "query".into(), BoltMap::default());

let bytes: Bytes = run.into_bytes(Version::V4_1).unwrap();

Expand All @@ -85,11 +86,7 @@ mod tests {
b'r',
b'y',
map::TINY,
map::TINY | 1,
string::TINY | 2,
b'd',
b'b',
string::TINY,
map::TINY,
])
);
}
Expand Down
12 changes: 8 additions & 4 deletions lib/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ impl Query {
self.params.value.contains_key(key)
}

pub(crate) async fn run(self, db: &str, connection: &mut ManagedConnection) -> Result<()> {
pub(crate) async fn run(
self,
db: Option<&str>,
connection: &mut ManagedConnection,
) -> Result<()> {
let request = BoltRequest::run(db, &self.query, self.params);
Self::try_run(request, connection)
.await
Expand All @@ -54,7 +58,7 @@ impl Query {

pub(crate) async fn run_retryable(
&self,
db: &str,
db: Option<&str>,
connection: &mut ManagedConnection,
) -> QueryResult<()> {
let request = BoltRequest::run(db, &self.query, self.params.clone());
Expand All @@ -63,7 +67,7 @@ impl Query {

pub(crate) async fn execute_retryable(
&self,
db: &str,
db: Option<&str>,
fetch_size: usize,
mut connection: ManagedConnection,
) -> QueryResult<DetachedRowStream> {
Expand All @@ -75,7 +79,7 @@ impl Query {

pub(crate) async fn execute_mut<'conn>(
self,
db: &str,
db: Option<&str>,
fetch_size: usize,
connection: &'conn mut ManagedConnection,
) -> Result<RowStream> {
Expand Down
10 changes: 5 additions & 5 deletions lib/src/txn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ use crate::{
/// When a transation is started, a dedicated connection is resered and moved into the handle which
/// will be released to the connection pool when the [`Txn`] handle is dropped.
pub struct Txn {
db: Database,
db: Option<Database>,
fetch_size: usize,
connection: ManagedConnection,
}

impl Txn {
pub(crate) async fn new(
db: Database,
db: Option<Database>,
fetch_size: usize,
mut connection: ManagedConnection,
) -> Result<Self> {
let begin = BoltRequest::begin(&db);
let begin = BoltRequest::begin(db.as_deref());
match connection.send_recv(begin).await? {
BoltResponse::Success(_) => Ok(Txn {
db,
Expand All @@ -49,12 +49,12 @@ impl Txn {

/// Runs a single query and discards the stream.
pub async fn run(&mut self, q: Query) -> Result<()> {
q.run(&self.db, &mut self.connection).await
q.run(self.db.as_deref(), &mut self.connection).await
}

/// Executes a query and returns a [`RowStream`]
pub async fn execute(&mut self, q: Query) -> Result<RowStream> {
q.execute_mut(&self.db, self.fetch_size, &mut self.connection)
q.execute_mut(self.db.as_deref(), self.fetch_size, &mut self.connection)
.await
}

Expand Down
Loading

0 comments on commit fc69715

Please sign in to comment.