diff --git a/postgres-native-tls/src/lib.rs b/postgres-native-tls/src/lib.rs index 2e2f04573..b631e8b44 100644 --- a/postgres-native-tls/src/lib.rs +++ b/postgres-native-tls/src/lib.rs @@ -1,33 +1,48 @@ +//! Native TLS support for the `postgres` crate. pub extern crate native_tls; extern crate postgres; use native_tls::TlsConnector; use postgres::tls::{Stream, TlsHandshake, TlsStream}; use std::error::Error; -use std::fmt::{self, Debug}; +use std::fmt; use std::io::{self, Read, Write}; #[cfg(test)] mod test; -pub struct NativeTls { - connector: TlsConnector, -} +/// A `TlsHandshake` implementation that uses the native-tls crate. +/// +/// Requires the `with-native-tls` feature. +pub struct NativeTls(TlsConnector); -impl Debug for NativeTls { +impl fmt::Debug for NativeTls { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.debug_struct("NativeTls").finish() } } impl NativeTls { + /// Creates a new `NativeTls` with its default configuration. pub fn new() -> Result { let connector = TlsConnector::builder().build()?; - Ok(NativeTls::with_connector(connector)) + Ok(NativeTls(connector)) + } + + /// Returns a reference to the inner `TlsConnector`. + pub fn connector(&self) -> &TlsConnector { + &self.0 } - pub fn with_connector(connector: TlsConnector) -> NativeTls { - NativeTls { connector } + /// Returns a mutable reference to the inner `TlsConnector`. + pub fn connector_mut(&mut self) -> &mut TlsConnector { + &mut self.0 + } +} + +impl From for NativeTls { + fn from(connector: TlsConnector) -> NativeTls { + NativeTls(connector) } } @@ -36,8 +51,8 @@ impl TlsHandshake for NativeTls { &self, domain: &str, stream: Stream, - ) -> Result, Box> { - let stream = self.connector.connect(domain, stream)?; + ) -> Result, Box> { + let stream = self.0.connect(domain, stream)?; Ok(Box::new(NativeTlsStream(stream))) } } diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index fae5fd9b2..c77a1d400 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -12,7 +12,7 @@ fn connect() { builder.add_root_certificate(cert); let connector = builder.build().unwrap(); - let handshake = NativeTls::with_connector(connector); + let handshake = NativeTls::from(connector); let conn = Connection::connect( "postgres://ssl_user@localhost:5433/postgres", TlsMode::Require(&handshake), diff --git a/postgres-openssl/src/lib.rs b/postgres-openssl/src/lib.rs index 8b01e83a2..f5eb45c98 100644 --- a/postgres-openssl/src/lib.rs +++ b/postgres-openssl/src/lib.rs @@ -1,19 +1,23 @@ +//! OpenSSL support for the `postgres` crate. pub extern crate openssl; extern crate postgres; -use openssl::error::ErrorStack; -use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslStream}; -use postgres::tls::{Stream, TlsHandshake, TlsStream}; use std::error::Error; -use std::fmt; use std::io::{self, Read, Write}; +use std::fmt; +use openssl::error::ErrorStack; +use openssl::ssl::{SslMethod, SslConnector, SslStream}; +use postgres::tls::{TlsStream, Stream, TlsHandshake}; #[cfg(test)] mod test; +/// A `TlsHandshake` implementation that uses OpenSSL. +/// +/// Requires the `with-openssl` feature. pub struct OpenSsl { connector: SslConnector, - config: Box Result<(), ErrorStack> + Sync + Send>, + disable_verification: bool, } impl fmt::Debug for OpenSsl { @@ -23,23 +27,39 @@ impl fmt::Debug for OpenSsl { } impl OpenSsl { + /// Creates a `OpenSsl` with `SslConnector`'s default configuration. pub fn new() -> Result { let connector = SslConnector::builder(SslMethod::tls())?.build(); - Ok(OpenSsl::with_connector(connector)) + Ok(OpenSsl::from(connector)) } - pub fn with_connector(connector: SslConnector) -> OpenSsl { - OpenSsl { - connector, - config: Box::new(|_| Ok(())), - } + /// Returns a reference to the inner `SslConnector`. + pub fn connector(&self) -> &SslConnector { + &self.connector + } + + /// Returns a mutable reference to the inner `SslConnector`. + pub fn connector_mut(&mut self) -> &mut SslConnector { + &mut self.connector } - pub fn callback(&mut self, f: F) - where - F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send, - { - self.config = Box::new(f); + /// If set, the + /// `SslConnector::danger_connect_without_providing_domain_for_certificate_verification_and_server_name_indication` + /// method will be used to connect. + /// + /// If certificate verification has been disabled in the `SslConnector`, verification must be + /// additionally disabled here for that setting to take effect. + pub fn danger_disable_hostname_verification(&mut self, disable_verification: bool) { + self.disable_verification = disable_verification; + } +} + +impl From for OpenSsl { + fn from(connector: SslConnector) -> OpenSsl { + OpenSsl { + connector: connector, + disable_verification: false, + } } } @@ -48,11 +68,13 @@ impl TlsHandshake for OpenSsl { &self, domain: &str, stream: Stream, - ) -> Result, Box> { + ) -> Result, Box> { let mut ssl = self.connector.configure()?; - (self.config)(&mut ssl)?; + if self.disable_verification { + ssl.set_use_server_name_indication(false); + ssl.set_verify_hostname(false); + } let stream = ssl.connect(domain, stream)?; - Ok(Box::new(OpenSslStream(stream))) } } diff --git a/postgres-openssl/src/test.rs b/postgres-openssl/src/test.rs index 1d3b7278e..711528c2d 100644 --- a/postgres-openssl/src/test.rs +++ b/postgres-openssl/src/test.rs @@ -7,7 +7,7 @@ use OpenSsl; fn require() { let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_ca_file("../test/server.crt").unwrap(); - let negotiator = OpenSsl::with_connector(builder.build()); + let negotiator = OpenSsl::from(builder.build()); let conn = Connection::connect( "postgres://ssl_user@localhost:5433/postgres", TlsMode::Require(&negotiator), @@ -19,7 +19,7 @@ fn require() { fn prefer() { let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_ca_file("../test/server.crt").unwrap(); - let negotiator = OpenSsl::with_connector(builder.build()); + let negotiator = OpenSsl::from(builder.build()); let conn = Connection::connect( "postgres://ssl_user@localhost:5433/postgres", TlsMode::Require(&negotiator),