diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 634b71de4b..d46253b1c0 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -33,21 +33,21 @@ impl PgConnection { ("TimeZone", "UTC"), ]; - if let Some(ref extra_float_digits) = options.extra_float_digits { + if let Some(extra_float_digits) = options.get_extra_float_digits() { params.push(("extra_float_digits", extra_float_digits)); } - if let Some(ref application_name) = options.application_name { + if let Some(application_name) = options.get_application_name() { params.push(("application_name", application_name)); } - if let Some(ref options) = options.options { + if let Some(options) = options.get_options() { params.push(("options", options)); } stream.write(Startup { - username: Some(&options.username), - database: options.database.as_deref(), + username: Some(options.get_username()), + database: Some(options.get_database()), params: ¶ms, })?; @@ -77,7 +77,7 @@ impl PgConnection { stream .send(Password::Cleartext( - options.password.as_deref().unwrap_or_default(), + options.get_password().unwrap_or_default(), )) .await?; } @@ -90,8 +90,8 @@ impl PgConnection { stream .send(Password::Md5 { - username: &options.username, - password: options.password.as_deref().unwrap_or_default(), + username: options.get_username(), + password: options.get_password().unwrap_or_default(), salt: body.salt, }) .await?; diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index 729cc1fcc5..ec64856f5b 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -52,7 +52,7 @@ pub(crate) async fn authenticate( BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding); // "n=" saslname ;; Usernames are prepared using SASLprep. - let username = format!("{}={}", USERNAME_ATTR, options.username); + let username = format!("{}={}", USERNAME_ATTR, options.get_username()); let username = match saslprep(&username) { Ok(v) => v, // TODO(danielakhterov): Remove panic when we have proper support for configuration errors @@ -87,7 +87,7 @@ pub(crate) async fn authenticate( // SaltedPassword := Hi(Normalize(password), salt, i) let salted_password = hi( - options.password.as_deref().unwrap_or_default(), + options.get_password().unwrap_or_default(), &cont.salt, cont.iterations, )?; diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index e8a1aedc47..025309fd6b 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -44,7 +44,14 @@ impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { let socket_result = match options.fetch_socket() { Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, - None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, + None => { + net::connect_tcp( + options.get_host(), + options.get_port(), + MaybeUpgradeTls(options), + ) + .await? + } }; let socket = socket_result?; diff --git a/sqlx-postgres/src/connection/tls.rs b/sqlx-postgres/src/connection/tls.rs index a49c9caa8c..080718f3fb 100644 --- a/sqlx-postgres/src/connection/tls.rs +++ b/sqlx-postgres/src/connection/tls.rs @@ -20,7 +20,7 @@ async fn maybe_upgrade( options: &PgConnectOptions, ) -> Result, Error> { // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS - match options.ssl_mode { + match options.get_ssl_mode() { // FIXME: Implement ALLOW PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)), @@ -46,15 +46,15 @@ async fn maybe_upgrade( } let accept_invalid_certs = !matches!( - options.ssl_mode, + options.get_ssl_mode(), PgSslMode::VerifyCa | PgSslMode::VerifyFull ); - let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull); + let accept_invalid_hostnames = !matches!(options.get_ssl_mode(), PgSslMode::VerifyFull); let config = TlsConfig { accept_invalid_certs, accept_invalid_hostnames, - hostname: &options.host, + hostname: options.get_host(), root_cert_path: options.ssl_root_cert.as_ref(), client_cert_path: options.ssl_client_cert.as_ref(), client_key_path: options.ssl_client_key.as_ref(), diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index b96c021be2..ceed876b7a 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -18,25 +18,19 @@ use crate::query_scalar::query_scalar; use crate::{PgConnectOptions, PgConnection, Postgres}; fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error> { - let mut options = PgConnectOptions::from_str(url)?; + let options = PgConnectOptions::from_str(url)?; // pull out the name of the database to create - let database = options - .database - .as_deref() - .unwrap_or(&options.username) - .to_owned(); + let database = options.get_database().to_owned(); // switch us to the maintenance database // use `postgres` _unless_ the database is postgres, in which case, use `template1` // this matches the behavior of the `createdb` util - options.database = if database == "postgres" { - Some("template1".into()) + if database == "postgres" { + Ok((options.database("template1"), database)) } else { - Some("postgres".into()) - }; - - Ok((options, database)) + Ok((options.database("postgres"), database)) + } } impl MigrateDatabase for Postgres { diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index efbc43989b..5fb6db9dc3 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -96,10 +96,10 @@ impl PgConnectOptions { pub(crate) fn apply_pgpass(mut self) -> Self { if self.password.is_none() { self.password = pgpass::load_password( - &self.host, - self.port, - &self.username, - self.database.as_deref(), + self.get_host(), + self.get_port(), + self.get_username(), + self.get_database(), ); } @@ -519,18 +519,34 @@ impl PgConnectOptions { &self.username } + /// Get the password. + /// + /// ```rust + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new() + /// .password("53C237"); + /// assert_eq!(options.get_password(), Some("53C237")); + /// ``` + pub fn get_password(&self) -> Option<&str> { + self.password.as_deref() + } + /// Get the current database name. /// + /// Defaults to username if not given. + /// /// # Example /// /// ```rust /// # use sqlx_postgres::PgConnectOptions; - /// let options = PgConnectOptions::new() - /// .database("postgres"); - /// assert!(options.get_database().is_some()); + /// let options = PgConnectOptions::new().database("postgres"); + /// assert_eq!(options.get_database(), "postgres"); + /// + /// let options = PgConnectOptions::new().username("alice"); + /// assert_eq!(options.get_database(), "alice"); /// ``` - pub fn get_database(&self) -> Option<&str> { - self.database.as_deref() + pub fn get_database(&self) -> &str { + self.database.as_deref().unwrap_or(&self.username) } /// Get the SSL mode. @@ -560,6 +576,19 @@ impl PgConnectOptions { self.application_name.as_deref() } + /// Get the extra float digits. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new(); + /// assert_eq!(options.get_extra_float_digits(), Some("2")); + /// ``` + pub fn get_extra_float_digits(&self) -> std::option::Option<&str> { + self.extra_float_digits.as_deref() + } + /// Get the options. /// /// # Example diff --git a/sqlx-postgres/src/options/pgpass.rs b/sqlx-postgres/src/options/pgpass.rs index bf16559548..f56e977e6e 100644 --- a/sqlx-postgres/src/options/pgpass.rs +++ b/sqlx-postgres/src/options/pgpass.rs @@ -5,12 +5,7 @@ use std::io::{BufRead, BufReader}; use std::path::PathBuf; /// try to load a password from the various pgpass file locations -pub fn load_password( - host: &str, - port: u16, - username: &str, - database: Option<&str>, -) -> Option { +pub fn load_password(host: &str, port: u16, username: &str, database: &str) -> Option { let custom_file = var_os("PGPASSFILE"); if let Some(file) = custom_file { if let Some(password) = @@ -39,7 +34,7 @@ fn load_password_from_file( host: &str, port: u16, username: &str, - database: Option<&str>, + database: &str, ) -> Option { let file = File::open(&path) .map_err(|e| { @@ -88,7 +83,7 @@ fn load_password_from_reader( host: &str, port: u16, username: &str, - database: Option<&str>, + database: &str, ) -> Option { let mut line = String::new(); @@ -129,7 +124,7 @@ fn load_password_from_line( host: &str, port: u16, username: &str, - database: Option<&str>, + database: &str, ) -> Option { let whole_line = line; @@ -140,7 +135,7 @@ fn load_password_from_line( _ => { matches_next_field(whole_line, &mut line, host)?; matches_next_field(whole_line, &mut line, &port.to_string())?; - matches_next_field(whole_line, &mut line, database.unwrap_or_default())?; + matches_next_field(whole_line, &mut line, database)?; matches_next_field(whole_line, &mut line, username)?; Some(line.to_owned()) } @@ -268,41 +263,24 @@ mod tests { "localhost", 5432, "foo", - Some("bar") + "bar", ), Some("baz".to_owned()) ); // wildcard assert_eq!( - load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", Some("bar")), - Some("baz".to_owned()) - ); - // accept wildcard with missing db - assert_eq!( - load_password_from_line("localhost:5432:*:foo:baz", "localhost", 5432, "foo", None), + load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", "bar"), Some("baz".to_owned()) ); // doesn't match assert_eq!( - load_password_from_line( - "thishost:5432:bar:foo:baz", - "thathost", - 5432, - "foo", - Some("bar") - ), + load_password_from_line("thishost:5432:bar:foo:baz", "thathost", 5432, "foo", "bar",), None ); // malformed entry assert_eq!( - load_password_from_line( - "localhost:5432:bar:foo", - "localhost", - 5432, - "foo", - Some("bar") - ), + load_password_from_line("localhost:5432:bar:foo", "localhost", 5432, "foo", "bar",), None ); } @@ -323,28 +301,23 @@ mod tests { // normal assert_eq!( - load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")), + load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", "bar"), Some("baz".to_owned()) ); // wildcard assert_eq!( - load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")), - Some("baz".to_owned()) - ); - // accept wildcard with missing db - assert_eq!( - load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None), + load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", "foobar"), Some("baz".to_owned()) ); // doesn't match assert_eq!( - load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")), + load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", "foobar"), None ); // malformed entry assert_eq!( - load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")), + load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", "foobar"), None ); } diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index 3e1cf0ddf7..84623931dc 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -108,14 +108,14 @@ async fn test_context(args: &TestArgs) -> Result, Error> { Err((existing, pool)) => { // Sanity checks. assert_eq!( - existing.connect_options().host, - pool.connect_options().host, + existing.connect_options().get_host(), + pool.connect_options().get_host(), "DATABASE_URL changed at runtime, host differs" ); assert_eq!( - existing.connect_options().database, - pool.connect_options().database, + existing.connect_options().get_database(), + pool.connect_options().get_database(), "DATABASE_URL changed at runtime, database differs" );