diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index b46b9d91b..e968dea4b 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -25,23 +25,23 @@ jobs: include: - rust: stable os: ubuntu-latest - features: "" + features: "--features run_tokio_rustls_tests" target: "x86_64-unknown-linux-gnu" - rust: beta os: ubuntu-latest - features: "" + features: "--features run_tokio_rustls_tests" target: "x86_64-unknown-linux-gnu" - rust: "msrv" os: ubuntu-latest - features: "" + features: "--features run_tokio_rustls_tests" target: "x86_64-unknown-linux-gnu" - rust: "stable" os: ubuntu-latest - features: "" + features: "--features run_tokio_rustls_tests" target: "x86_64-unknown-linux-musl" - rust: "stable" os: macos-latest - features: "" + features: "--features run_tokio_rustls_tests" target: "x86_64-apple-darwin" - rust: "stable" os: ubuntu-latest @@ -59,6 +59,18 @@ jobs: os: macos-latest features: "--all-features" target: "x86_64-apple-darwin" + - rust: "stable" + os: ubuntu-latest + features: "--no-default-features --features rustls21" + target: "x86_64-unknown-linux-gnu" + - rust: "stable" + os: ubuntu-latest + features: "--no-default-features --features rustls22" + target: "x86_64-unknown-linux-gnu" + - rust: "stable" + os: ubuntu-latest + features: "--no-default-features --features rustls23" + target: "x86_64-unknown-linux-gnu" steps: - name: Checkout sources uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 @@ -237,7 +249,7 @@ jobs: shared-key: "stable-${{matrix.target}}" - name: Run clippy - run: cargo clippy --target ${{matrix.target}} --workspace --all-targets --all-features -- -D warnings + run: cargo clippy --target ${{matrix.target}} --workspace --all-targets --features hardware-timestamping,unstable_ntpv5,unstable_nts-pool,__internal-fuzz,__internal-test,__internal-api,ntpv5,nts-pool -- -D warnings - name: Run clippy (fuzzers) run: cargo clippy --target ${{matrix.target}} --manifest-path ./fuzz/Cargo.toml --all-targets -- -D warnings if: ${{matrix.fuzzer}} diff --git a/Cargo.lock b/Cargo.lock index 4f59e333c..1fec21cfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "bitflags" version = "2.6.0" @@ -357,8 +363,14 @@ dependencies = [ "arbitrary", "md-5", "rand", - "rustls", - "rustls-pemfile", + "rustls 0.21.12", + "rustls 0.22.4", + "rustls 0.23.20", + "rustls-native-certs 0.6.3", + "rustls-native-certs 0.7.3", + "rustls-pemfile 1.0.4", + "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_test", "tracing", @@ -374,9 +386,7 @@ dependencies = [ "libc", "ntp-proto", "rand", - "rustls", - "rustls-native-certs", - "rustls-pemfile", + "rustls 0.23.20", "serde", "serde_json", "timestamped-socket", @@ -392,9 +402,9 @@ name = "nts-pool-ke" version = "1.3.1" dependencies = [ "ntp-proto", - "rustls", - "rustls-native-certs", - "rustls-pemfile", + "rustls 0.23.20", + "rustls-native-certs 0.7.3", + "rustls-pemfile 2.2.0", "serde", "tokio", "tokio-rustls", @@ -524,6 +534,32 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + +[[package]] +name = "rustls" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + [[package]] name = "rustls" version = "0.23.20" @@ -534,11 +570,23 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.102.8", "subtle", "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework", +] + [[package]] name = "rustls-native-certs" version = "0.7.3" @@ -546,12 +594,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" dependencies = [ "openssl-probe", - "rustls-pemfile", + "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -567,6 +624,16 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustls-webpki" version = "0.102.8" @@ -593,6 +660,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -769,7 +846,7 @@ version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ - "rustls", + "rustls 0.23.20", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index c109c6e0c..005b83f0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,9 +48,14 @@ timestamped-socket = "0.2.2" clock-steering = "0.2.1" # TLS -rustls = { version = "0.23.0", default-features = false, features = ["ring", "logging", "std", "tls12"] } -rustls-pemfile = "2.0" -rustls-native-certs = "0.7.0" +rustls23 = { package = "rustls", version = "0.23.0", default-features = false, features = ["ring", "logging", "std", "tls12"] } +rustls22 = { package = "rustls", version = "0.22.0", default-features = false, features = ["ring", "logging", "tls12"] } +rustls21 = { package = "rustls", version = "0.21.0" } +rustls-pemfile1 = { package = "rustls-pemfile", version = "1.0" } +rustls-pemfile2 = { package = "rustls-pemfile", version = "2.0" } +rustls-pki-types = "1.2" +rustls-native-certs6 = { package = "rustls-native-certs", version = "0.6" } +rustls-native-certs7 = { package = "rustls-native-certs", version = "0.7" } tokio-rustls = { version = "0.26.0", default-features = false, features = ["ring", "logging", "tls12"] } # testing only # crypto @@ -62,4 +67,4 @@ zeroize = "1.7" # our own crates used as dependencies, same version as the workspace version # NOTE: keep this part at the bottom of the file, do not change this line -ntp-proto = { version = "1.3.1", path = "./ntp-proto", features = ["__internal-api"] } +ntp-proto = { version = "1.3.1", path = "./ntp-proto", default-features = false, features = ["__internal-api"] } diff --git a/ntp-proto/Cargo.toml b/ntp-proto/Cargo.toml index 3f6dcdbb8..0d5eeefb1 100644 --- a/ntp-proto/Cargo.toml +++ b/ntp-proto/Cargo.toml @@ -13,12 +13,15 @@ rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = [] +default = ["rustls23"] __internal-fuzz = ["arbitrary", "__internal-api"] __internal-test = ["__internal-api"] __internal-api = [] ntpv5 = [] -nts-pool = [] +nts-pool = [ "rustls23" ] +rustls23 = [ "dep:rustls23", "dep:rustls-pemfile2", "dep:rustls-native-certs7" ] +rustls22 = [ "dep:rustls22", "dep:rustls-pemfile2", "dep:rustls-native-certs7", "dep:rustls-pki-types" ] +rustls21 = [ "dep:rustls21", "dep:rustls-pemfile1", "dep:rustls-native-certs6" ] [dependencies] # Note: md5 is needed to calculate ReferenceIDs for IPv6 addresses per RFC5905 @@ -27,11 +30,17 @@ rand.workspace = true tracing.workspace = true serde.workspace = true arbitrary = { workspace = true, optional = true } -rustls.workspace = true +rustls23 = { workspace = true, optional = true } +rustls22 = { workspace = true, optional = true } +rustls21 = { workspace = true, optional = true } +rustls-pki-types = { workspace = true, optional = true } +rustls-pemfile2 = { workspace = true, optional = true } +rustls-pemfile1 = { workspace = true, optional = true } +rustls-native-certs6 = { workspace = true, optional = true } +rustls-native-certs7 = { workspace = true, optional = true } aead.workspace = true aes-siv.workspace = true zeroize.workspace = true [dev-dependencies] -rustls-pemfile.workspace = true serde_test.workspace = true diff --git a/ntp-proto/src/lib.rs b/ntp-proto/src/lib.rs index ec06cf9c3..fab02f281 100644 --- a/ntp-proto/src/lib.rs +++ b/ntp-proto/src/lib.rs @@ -25,7 +25,6 @@ mod time_types; #[cfg(feature = "nts-pool")] mod nts_pool_ke; -#[cfg(feature = "nts-pool")] pub mod tls_utils; pub(crate) mod exitcode { diff --git a/ntp-proto/src/nts_pool_ke.rs b/ntp-proto/src/nts_pool_ke.rs index 0cfe5fad7..908048506 100644 --- a/ntp-proto/src/nts_pool_ke.rs +++ b/ntp-proto/src/nts_pool_ke.rs @@ -89,7 +89,7 @@ pub struct ClientToPoolData { impl ClientToPoolData { pub fn extract_nts_keys( &self, - stream: &rustls::ConnectionCommon, + stream: &rustls23::ConnectionCommon, ) -> Result { self.algorithm .extract_nts_keys(self.protocol, stream) diff --git a/ntp-proto/src/nts_record.rs b/ntp-proto/src/nts_record.rs index 554366fe6..b31b90cdc 100644 --- a/ntp-proto/src/nts_record.rs +++ b/ntp-proto/src/nts_record.rs @@ -5,7 +5,7 @@ use std::{ sync::Arc, }; -use rustls::pki_types::ServerName; +use crate::tls_utils::{self, ServerName}; use crate::{ cookiestash::CookieStash, @@ -676,9 +676,9 @@ pub enum KeyExchangeError { InvalidFixedKeyLength, NoCookies, Io(std::io::Error), - Tls(rustls::Error), - Certificate(rustls::Error), - DnsName(rustls::pki_types::InvalidDnsNameError), + Tls(tls_utils::Error), + Certificate(tls_utils::Error), + DnsName(tls_utils::InvalidDnsNameError), IncompleteResponse, } @@ -719,14 +719,14 @@ impl From for KeyExchangeError { } } -impl From for KeyExchangeError { - fn from(value: rustls::Error) -> Self { +impl From for KeyExchangeError { + fn from(value: crate::tls_utils::Error) -> Self { Self::Tls(value) } } -impl From for KeyExchangeError { - fn from(value: rustls::pki_types::InvalidDnsNameError) -> Self { +impl From for KeyExchangeError { + fn from(value: tls_utils::InvalidDnsNameError) -> Self { Self::DnsName(value) } } @@ -847,8 +847,8 @@ impl AeadAlgorithm { pub(crate) fn extract_nts_keys( &self, protocol: ProtocolId, - tls_connection: &rustls::ConnectionCommon, - ) -> Result { + tls_connection: &tls_utils::ConnectionCommon, + ) -> Result { match self { AeadAlgorithm::AeadAesSivCmac256 => { let c2s = extract_nts_key(tls_connection, self.c2s_context(protocol))?; @@ -923,9 +923,9 @@ impl std::fmt::Debug for NtsKeys { } fn extract_nts_key, ConnectionData>( - tls_connection: &rustls::ConnectionCommon, + tls_connection: &tls_utils::ConnectionCommon, context: [u8; 5], -) -> Result { +) -> Result { let mut key = T::default(); tls_connection.export_keying_material( &mut key, @@ -1156,7 +1156,7 @@ pub struct KeyExchangeResult { } pub struct KeyExchangeClient { - tls_connection: rustls::ClientConnection, + tls_connection: tls_utils::ClientConnection, decoder: KeyExchangeResultDecoder, server_name: String, } @@ -1239,14 +1239,14 @@ impl KeyExchangeClient { // should only be used in tests! fn new_without_tls_write( server_name: String, - mut tls_config: rustls::ClientConfig, + mut tls_config: tls_utils::ClientConfig, ) -> Result { // Ensure we send only ntske/1 as alpn tls_config.alpn_protocols.clear(); tls_config.alpn_protocols.push(b"ntske/1".to_vec()); // TLS only works when the server name is a DNS name; an IP address does not work - let tls_connection = rustls::ClientConnection::new( + let tls_connection = tls_utils::ClientConnection::new( Arc::new(tls_config), ServerName::try_from(&server_name as &str)?.to_owned(), )?; @@ -1260,7 +1260,7 @@ impl KeyExchangeClient { pub fn new( server_name: String, - tls_config: rustls::ClientConfig, + tls_config: tls_utils::ClientConfig, ntp_version: Option, denied_servers: impl IntoIterator, ) -> Result { @@ -1539,13 +1539,13 @@ impl KeyExchangeServerDecoder { #[derive(Debug)] pub struct KeyExchangeServer { - tls_connection: rustls::ServerConnection, + tls_connection: tls_utils::ServerConnection, state: State, keyset: Arc, ntp_port: Option, ntp_server: Option, #[cfg(feature = "nts-pool")] - pool_certificates: Arc<[rustls::pki_types::CertificateDer<'static>]>, + pool_certificates: Arc<[tls_utils::Certificate]>, } #[derive(Debug)] @@ -1573,7 +1573,7 @@ impl KeyExchangeServer { } fn send_records( - tls_connection: &mut rustls::ServerConnection, + tls_connection: &mut tls_utils::ServerConnection, records: &[NtsRecord], ) -> std::io::Result<()> { let mut buffer = Vec::with_capacity(1024); @@ -1587,7 +1587,10 @@ impl KeyExchangeServer { Ok(()) } - fn send_error_record(tls_connection: &mut rustls::ServerConnection, error: &KeyExchangeError) { + fn send_error_record( + tls_connection: &mut tls_utils::ServerConnection, + error: &KeyExchangeError, + ) { let error_records = [ NtsRecord::Error { errorcode: error.to_error_code(), @@ -1605,7 +1608,7 @@ impl KeyExchangeServer { pub fn progress( mut self, - ) -> ControlFlow, Self> { + ) -> ControlFlow, Self> { // Move any received data from tls to decoder if let Err(e) = self.tls_connection.process_new_packets() { return ControlFlow::Break(Err(e.into())); @@ -1674,7 +1677,7 @@ impl KeyExchangeServer { } } - fn end_of_file(self) -> Result { + fn end_of_file(self) -> Result { match self.state { State::Active { .. } => { // there are no more client bytes, but decoding was not finished yet @@ -1736,7 +1739,7 @@ impl KeyExchangeServer { fn decoder_done( mut self, data: ServerKeyExchangeData, - ) -> ControlFlow, Self> { + ) -> ControlFlow, Self> { let algorithm = data.algorithm; let protocol = data.protocol; //TODO: see comment in fn server_key_exchange_records() @@ -1774,17 +1777,17 @@ impl KeyExchangeServer { } pub fn new( - tls_config: Arc, + tls_config: Arc, keyset: Arc, ntp_port: Option, ntp_server: Option, - pool_certificates: Arc<[rustls::pki_types::CertificateDer<'static>]>, + pool_certificates: Arc<[tls_utils::Certificate]>, ) -> Result { // Ensure we send only ntske/1 as alpn debug_assert_eq!(tls_config.alpn_protocols, &[b"ntske/1".to_vec()]); // TLS only works when the server name is a DNS name; an IP address does not work - let tls_connection = rustls::ServerConnection::new(tls_config)?; + let tls_connection = tls_utils::ServerConnection::new(tls_config)?; #[cfg(not(feature = "nts-pool"))] let _ = pool_certificates; @@ -2836,34 +2839,46 @@ mod test { #[test] fn test_keyexchange_client() { - let cert_chain: Vec = rustls_pemfile::certs( + let cert_chain: Vec = tls_utils::pemfile::certs( &mut std::io::BufReader::new(include_bytes!("../test-keys/end.fullchain.pem") as &[u8]), ) .map(|res| res.unwrap()) .collect(); - let key_der = rustls_pemfile::pkcs8_private_keys(&mut std::io::BufReader::new( + + let key_der = tls_utils::pemfile::pkcs8_private_keys(&mut std::io::BufReader::new( include_bytes!("../test-keys/end.key") as &[u8], )) .map(|res| res.unwrap()) .next() .unwrap(); - let serverconfig = rustls::ServerConfig::builder() + + let serverconfig = tls_utils::server_config_builder() .with_no_client_auth() .with_single_cert(cert_chain, key_der.into()) .unwrap(); - let mut root_store = rustls::RootCertStore::empty(); + + let mut root_store = tls_utils::RootCertStore::empty(); + #[cfg(any(feature = "rustls22", feature = "rustls23"))] root_store.add_parsable_certificates( - rustls_pemfile::certs(&mut std::io::BufReader::new(include_bytes!( + tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( "../test-keys/testca.pem" ) as &[u8])) .map(|res| res.unwrap()), ); + #[cfg(not(any(feature = "rustls22", feature = "rustls23")))] + root_store.add_parsable_certificates( + &tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( + "../test-keys/testca.pem" + ) as &[u8])) + .map(|res| res.unwrap()) + .collect::>(), + ); - let clientconfig = rustls::ClientConfig::builder() + let clientconfig = tls_utils::client_config_builder() .with_root_certificates(root_store) .with_no_client_auth(); - let mut server = rustls::ServerConnection::new(Arc::new(serverconfig)).unwrap(); + let mut server = tls_utils::ServerConnection::new(Arc::new(serverconfig)).unwrap(); let mut client = KeyExchangeClient::new("localhost".into(), clientconfig, None, vec![]).unwrap(); @@ -2907,34 +2922,46 @@ mod test { } fn client_server_pair(client_type: ClientType) -> (KeyExchangeClient, KeyExchangeServer) { - let cert_chain: Vec = rustls_pemfile::certs( + #[allow(unused)] + use tls_utils::CloneKeyShim; + + let cert_chain: Vec = tls_utils::pemfile::certs( &mut std::io::BufReader::new(include_bytes!("../test-keys/end.fullchain.pem") as &[u8]), ) .map(|res| res.unwrap()) .collect(); - let key_der = rustls_pemfile::pkcs8_private_keys(&mut std::io::BufReader::new( + let key_der = tls_utils::pemfile::pkcs8_private_keys(&mut std::io::BufReader::new( include_bytes!("../test-keys/end.key") as &[u8], )) .map(|res| res.unwrap()) .next() .unwrap(); - let mut root_store = rustls::RootCertStore::empty(); + let mut root_store = tls_utils::RootCertStore::empty(); + #[cfg(any(feature = "rustls22", feature = "rustls23"))] root_store.add_parsable_certificates( - rustls_pemfile::certs(&mut std::io::BufReader::new(include_bytes!( + tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( "../test-keys/testca.pem" ) as &[u8])) .map(|res| res.unwrap()), ); + #[cfg(not(any(feature = "rustls22", feature = "rustls23")))] + root_store.add_parsable_certificates( + &tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( + "../test-keys/testca.pem" + ) as &[u8])) + .map(|res| res.unwrap()) + .collect::>(), + ); - let mut serverconfig = rustls::ServerConfig::builder() + let mut serverconfig = tls_utils::server_config_builder() .with_client_cert_verifier(Arc::new( #[cfg(not(feature = "nts-pool"))] - rustls::server::NoClientAuth, + tls_utils::NoClientAuth, #[cfg(feature = "nts-pool")] crate::tls_utils::AllowAnyAnonymousOrCertificateBearingClient::new( // We know that our previous call to ServerConfig::builder already // installed a default provider, but this is undocumented - rustls::crypto::CryptoProvider::get_default().unwrap(), + rustls23::crypto::CryptoProvider::get_default().unwrap(), ), )) .with_single_cert(cert_chain.clone(), key_der.clone_key().into()) @@ -2944,10 +2971,10 @@ mod test { serverconfig.alpn_protocols.push(b"ntske/1".to_vec()); let clientconfig = match client_type { - ClientType::Uncertified => rustls::ClientConfig::builder() + ClientType::Uncertified => tls_utils::client_config_builder() .with_root_certificates(root_store) .with_no_client_auth(), - ClientType::Certified => rustls::ClientConfig::builder() + ClientType::Certified => tls_utils::client_config_builder() .with_root_certificates(root_store) .with_client_auth_cert(cert_chain, key_der.into()) .unwrap(), @@ -2955,7 +2982,7 @@ mod test { let keyset = KeySetProvider::new(8).get(); - let pool_cert: Vec = rustls_pemfile::certs( + let pool_cert: Vec = tls_utils::pemfile::certs( &mut std::io::BufReader::new(include_bytes!("../test-keys/end.pem") as &[u8]), ) .map(|res| res.unwrap()) diff --git a/ntp-proto/src/tls_utils.rs b/ntp-proto/src/tls_utils.rs index b57373f0d..fe576fe34 100644 --- a/ntp-proto/src/tls_utils.rs +++ b/ntp-proto/src/tls_utils.rs @@ -1,68 +1,315 @@ -/// The intent of this ClientCertVerifier is that it accepts any connections that are either -/// a.) not presenting a client certificate -/// b.) are presenting a well-formed, but otherwise not checked (against a trust root) client certificate -/// -/// This is because RusTLS apparently doesn't accept every kind of self-signed certificate. -/// -/// The only goal of this ClientCertVerifier is to achieve that, if a client presents a TLS certificate, -/// this certificate shows up in the .peer_certificates() for that connection. - -#[derive(Debug)] -pub struct AllowAnyAnonymousOrCertificateBearingClient { - supported_algs: WebPkiSupportedAlgorithms, +#[cfg(feature = "rustls23")] +mod rustls23_shim { + /// The intent of this ClientCertVerifier is that it accepts any connections that are either + /// a.) not presenting a client certificate + /// b.) are presenting a well-formed, but otherwise not checked (against a trust root) client certificate + /// + /// This is because RusTLS apparently doesn't accept every kind of self-signed certificate. + /// + /// The only goal of this ClientCertVerifier is to achieve that, if a client presents a TLS certificate, + /// this certificate shows up in the .peer_certificates() for that connection. + #[cfg(feature = "nts-pool")] + #[derive(Debug)] + pub struct AllowAnyAnonymousOrCertificateBearingClient { + supported_algs: WebPkiSupportedAlgorithms, + } + + #[cfg(feature = "nts-pool")] + use rustls23::{ + crypto::{CryptoProvider, WebPkiSupportedAlgorithms}, + pki_types::CertificateDer, + server::danger::ClientCertVerified, + }; + + #[cfg(feature = "nts-pool")] + impl AllowAnyAnonymousOrCertificateBearingClient { + pub fn new(provider: &CryptoProvider) -> Self { + AllowAnyAnonymousOrCertificateBearingClient { + supported_algs: provider.signature_verification_algorithms, + } + } + } + + #[cfg(feature = "nts-pool")] + impl rustls23::server::danger::ClientCertVerifier for AllowAnyAnonymousOrCertificateBearingClient { + fn verify_client_cert( + &self, + _end_entity: &CertificateDer, + _intermediates: &[CertificateDer], + _now: rustls23::pki_types::UnixTime, + ) -> Result { + Ok(ClientCertVerified::assertion()) + } + + fn client_auth_mandatory(&self) -> bool { + false + } + + fn root_hint_subjects(&self) -> &[rustls23::DistinguishedName] { + &[] + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls23::pki_types::CertificateDer<'_>, + dss: &rustls23::DigitallySignedStruct, + ) -> Result { + rustls23::crypto::verify_tls12_signature(message, cert, dss, &self.supported_algs) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls23::pki_types::CertificateDer<'_>, + dss: &rustls23::DigitallySignedStruct, + ) -> Result { + rustls23::crypto::verify_tls13_signature(message, cert, dss, &self.supported_algs) + } + + fn supported_verify_schemes(&self) -> Vec { + self.supported_algs.supported_schemes() + } + } + + pub use rustls23::pki_types::InvalidDnsNameError; + pub use rustls23::pki_types::ServerName; + pub use rustls23::server::NoClientAuth; + pub use rustls23::version::TLS13; + pub use rustls23::ClientConfig; + pub use rustls23::ClientConnection; + pub use rustls23::ConnectionCommon; + pub use rustls23::Error; + pub use rustls23::RootCertStore; + pub use rustls23::ServerConfig; + pub use rustls23::ServerConnection; + + pub type Certificate = rustls23::pki_types::CertificateDer<'static>; + pub type PrivateKey = rustls23::pki_types::PrivateKeyDer<'static>; + + pub mod pemfile { + pub use rustls_native_certs7::load_native_certs; + pub use rustls_pemfile2::certs; + pub use rustls_pemfile2::pkcs8_private_keys; + pub use rustls_pemfile2::private_key; + + pub fn rootstore_ref_shim(cert: &super::Certificate) -> super::Certificate { + cert.clone() + } + } + + pub trait CloneKeyShim {} + + pub fn client_config_builder( + ) -> rustls23::ConfigBuilder { + ClientConfig::builder() + } + + pub fn client_config_builder_with_protocol_versions( + versions: &[&'static rustls23::SupportedProtocolVersion], + ) -> rustls23::ConfigBuilder { + ClientConfig::builder_with_protocol_versions(versions) + } + + pub fn server_config_builder( + ) -> rustls23::ConfigBuilder { + ServerConfig::builder() + } + + pub fn server_config_builder_with_protocol_versions( + versions: &[&'static rustls23::SupportedProtocolVersion], + ) -> rustls23::ConfigBuilder { + ServerConfig::builder_with_protocol_versions(versions) + } } -use rustls::pki_types::CertificateDer; -use rustls::{ - crypto::{CryptoProvider, WebPkiSupportedAlgorithms}, - server::danger::ClientCertVerified, -}; +#[cfg(feature = "rustls22")] +mod rustls22_shim { + pub use rustls22::server::NoClientAuth; + pub use rustls22::version::TLS13; + pub use rustls22::ClientConfig; + pub use rustls22::ClientConnection; + pub use rustls22::ConnectionCommon; + pub use rustls22::Error; + pub use rustls22::RootCertStore; + pub use rustls22::ServerConfig; + pub use rustls22::ServerConnection; + pub use rustls_pki_types::InvalidDnsNameError; + pub use rustls_pki_types::ServerName; + + pub type Certificate = rustls_pki_types::CertificateDer<'static>; + pub type PrivateKey = rustls_pki_types::PrivateKeyDer<'static>; -impl AllowAnyAnonymousOrCertificateBearingClient { - pub fn new(provider: &CryptoProvider) -> Self { - AllowAnyAnonymousOrCertificateBearingClient { - supported_algs: provider.signature_verification_algorithms, + pub mod pemfile { + pub use rustls_native_certs7::load_native_certs; + pub use rustls_pemfile2::certs; + pub use rustls_pemfile2::pkcs8_private_keys; + pub use rustls_pemfile2::private_key; + + pub fn rootstore_ref_shim(cert: &super::Certificate) -> super::Certificate { + cert.clone() } } + + pub trait CloneKeyShim {} + + pub fn client_config_builder( + ) -> rustls22::ConfigBuilder { + ClientConfig::builder() + } + + pub fn client_config_builder_with_protocol_versions( + versions: &[&'static rustls22::SupportedProtocolVersion], + ) -> rustls22::ConfigBuilder { + ClientConfig::builder_with_protocol_versions(versions) + } + + pub fn server_config_builder( + ) -> rustls22::ConfigBuilder { + ServerConfig::builder() + } + + pub fn server_config_builder_with_protocol_versions( + versions: &[&'static rustls22::SupportedProtocolVersion], + ) -> rustls22::ConfigBuilder { + ServerConfig::builder_with_protocol_versions(versions) + } } -impl rustls::server::danger::ClientCertVerifier for AllowAnyAnonymousOrCertificateBearingClient { - fn verify_client_cert( - &self, - _end_entity: &CertificateDer, - _intermediates: &[CertificateDer], - _now: rustls::pki_types::UnixTime, - ) -> Result { - Ok(ClientCertVerified::assertion()) +#[cfg(feature = "rustls21")] +mod rustls21_shim { + pub use rustls21::client::InvalidDnsNameError; + pub use rustls21::client::ServerName; + pub use rustls21::server::NoClientAuth; + pub use rustls21::version::TLS13; + pub use rustls21::Certificate; + pub use rustls21::ClientConfig; + pub use rustls21::ClientConnection; + pub use rustls21::ConnectionCommon; + pub use rustls21::Error; + pub use rustls21::PrivateKey; + pub use rustls21::RootCertStore; + pub use rustls21::ServerConfig; + pub use rustls21::ServerConnection; + + pub fn client_config_builder( + ) -> rustls21::ConfigBuilder { + ClientConfig::builder().with_safe_defaults() } - fn client_auth_mandatory(&self) -> bool { - false + pub fn server_config_builder( + ) -> rustls21::ConfigBuilder { + ServerConfig::builder().with_safe_defaults() } - fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { - &[] + pub fn client_config_builder_with_protocol_versions( + versions: &[&'static rustls21::SupportedProtocolVersion], + ) -> rustls21::ConfigBuilder { + // Expect is ok here as this should never fail (not user controlled) + ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(versions) + .expect("Could not set protocol versions") + } + pub fn server_config_builder_with_protocol_versions( + versions: &[&'static rustls21::SupportedProtocolVersion], + ) -> rustls21::ConfigBuilder { + // Expect is ok here as this should never fail (not user controlled) + ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(versions) + .expect("Could not set protocol versions") } - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &rustls::pki_types::CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - rustls::crypto::verify_tls12_signature(message, cert, dss, &self.supported_algs) + pub trait CloneKeyShim { + fn clone_key(&self) -> Self; } - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &rustls::pki_types::CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - rustls::crypto::verify_tls13_signature(message, cert, dss, &self.supported_algs) + impl CloneKeyShim for PrivateKey { + fn clone_key(&self) -> Self { + self.clone() + } } - fn supported_verify_schemes(&self) -> Vec { - self.supported_algs.supported_schemes() + pub mod pemfile { + enum Either { + L(T), + R(U), + } + impl Iterator for Either + where + T: Iterator, + U: Iterator, + { + type Item = V; + + fn next(&mut self) -> Option { + match self { + Self::L(l) => l.next(), + Self::R(r) => r.next(), + } + } + } + + pub fn certs( + rd: &mut dyn std::io::BufRead, + ) -> impl Iterator> { + match rustls_pemfile1::certs(rd) { + Ok(v) => Either::L(v.into_iter().map(super::Certificate).map(Ok)), + Err(e) => Either::R(core::iter::once(Err(e))), + } + } + + pub fn pkcs8_private_keys( + rd: &mut dyn std::io::BufRead, + ) -> impl Iterator> { + match rustls_pemfile1::pkcs8_private_keys(rd) { + Ok(v) => Either::L(v.into_iter().map(super::PrivateKey).map(Ok)), + Err(e) => Either::R(core::iter::once(Err(e))), + } + } + + pub fn private_key( + rd: &mut dyn std::io::BufRead, + ) -> Result, std::io::Error> { + for item in std::iter::from_fn(|| rustls_pemfile1::read_one(rd).transpose()) { + match item { + Ok(rustls_pemfile1::Item::RSAKey(key)) + | Ok(rustls_pemfile1::Item::PKCS8Key(key)) + | Ok(rustls_pemfile1::Item::ECKey(key)) => { + return Ok(Some(super::PrivateKey(key))) + } + Err(e) => return Err(e), + _ => {} + } + } + + Ok(None) + } + + pub fn load_native_certs() -> Result, std::io::Error> { + Ok(rustls_native_certs6::load_native_certs()? + .into_iter() + .map(|v| super::Certificate(v.0)) + .collect()) + } + + pub fn rootstore_ref_shim(cert: &super::Certificate) -> &super::Certificate { + cert + } } } + +#[cfg(feature = "rustls23")] +pub use rustls23_shim::*; + +#[cfg(all(feature = "rustls22", not(feature = "rustls23")))] +pub use rustls22_shim::*; + +#[cfg(all( + feature = "rustls21", + not(any(feature = "rustls23", feature = "rustls22")) +))] +pub use rustls21_shim::*; diff --git a/ntpd/Cargo.toml b/ntpd/Cargo.toml index 6e1881456..9e12b4f7b 100644 --- a/ntpd/Cargo.toml +++ b/ntpd/Cargo.toml @@ -27,19 +27,21 @@ clock-steering.workspace = true serde.workspace = true serde_json.workspace = true -rustls.workspace = true -rustls-native-certs.workspace = true -rustls-pemfile.workspace = true +rustls23 = { workspace = true, optional = true } +tokio-rustls = { workspace = true, optional = true } [dev-dependencies] ntp-proto = { workspace = true, features = ["__internal-test",] } -tokio-rustls.workspace = true [features] -default = [] +default = [ "rustls23" ] hardware-timestamping = [] unstable_ntpv5 = ["ntp-proto/ntpv5"] -unstable_nts-pool = [ "ntp-proto/nts-pool" ] +unstable_nts-pool = [ "ntp-proto/nts-pool", "dep:rustls23" ] +rustls23 = [ "ntp-proto/rustls23" ] +rustls22 = [ "ntp-proto/rustls22" ] +rustls21 = [ "ntp-proto/rustls21" ] +run_tokio_rustls_tests = [ "rustls23", "dep:tokio-rustls" ] [lib] name = "ntpd" diff --git a/ntpd/src/daemon/config/ntp_source.rs b/ntpd/src/daemon/config/ntp_source.rs index bc6c96cf4..d7f7eb271 100644 --- a/ntpd/src/daemon/config/ntp_source.rs +++ b/ntpd/src/daemon/config/ntp_source.rs @@ -6,10 +6,10 @@ use std::{ sync::{Arc, Mutex}, }; +use ntp_proto::tls_utils::Certificate; use ntp_proto::NtpDuration; #[cfg(feature = "unstable_ntpv5")] use ntp_proto::NtpVersion; -use rustls::pki_types::CertificateDer; use serde::{de, Deserialize, Deserializer}; use super::super::keyexchange::certificates_from_file; @@ -46,7 +46,7 @@ pub struct NtsSourceConfig { default = "default_certificate_authorities", rename = "certificate-authority" )] - pub certificate_authorities: Arc<[CertificateDer<'static>]>, + pub certificate_authorities: Arc<[Certificate]>, #[cfg(feature = "unstable_ntpv5")] #[serde(default, deserialize_with = "deserialize_ntp_version")] pub ntp_version: Option, @@ -54,7 +54,7 @@ pub struct NtsSourceConfig { fn deserialize_certificate_authorities<'de, D>( deserializer: D, -) -> Result]>, D::Error> +) -> Result, D::Error> where D: Deserializer<'de>, { @@ -69,7 +69,7 @@ where } } -fn default_certificate_authorities() -> Arc<[CertificateDer<'static>]> { +fn default_certificate_authorities() -> Arc<[Certificate]> { Arc::from([]) } @@ -102,7 +102,7 @@ pub struct NtsPoolSourceConfig { default = "default_certificate_authorities", rename = "certificate-authority" )] - pub certificate_authorities: Arc<[CertificateDer<'static>]>, + pub certificate_authorities: Arc<[Certificate]>, #[serde(default = "max_sources_default")] pub count: usize, #[cfg(feature = "unstable_ntpv5")] diff --git a/ntpd/src/daemon/keyexchange.rs b/ntpd/src/daemon/keyexchange.rs index 6ffd1f792..22d3da88b 100644 --- a/ntpd/src/daemon/keyexchange.rs +++ b/ntpd/src/daemon/keyexchange.rs @@ -9,13 +9,10 @@ use std::{ }; use libc::{ECONNABORTED, EMFILE, ENFILE, ENOBUFS, ENOMEM}; +use ntp_proto::tls_utils::{self, Certificate, PrivateKey, TLS13}; use ntp_proto::{ KeyExchangeClient, KeyExchangeError, KeyExchangeResult, KeyExchangeServer, KeySet, NtpVersion, }; -use rustls::{ - pki_types::{CertificateDer, PrivateKeyDer}, - version::TLS13, -}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::TcpListener, @@ -27,12 +24,14 @@ use super::config::NtsKeConfig; use super::exitcode; async fn build_client_config( - extra_certificates: &[CertificateDer<'_>], -) -> Result { + extra_certificates: &[Certificate], +) -> Result { let mut roots = tokio::task::spawn_blocking(move || { - let mut roots = rustls::RootCertStore::empty(); - for cert in rustls_native_certs::load_native_certs()? { - roots.add(cert).map_err(KeyExchangeError::Certificate)?; + let mut roots = tls_utils::RootCertStore::empty(); + for cert in tls_utils::pemfile::load_native_certs()? { + roots + .add(tls_utils::pemfile::rootstore_ref_shim(&cert)) + .map_err(KeyExchangeError::Certificate)?; } Ok::<_, KeyExchangeError>(roots) }) @@ -41,12 +40,12 @@ async fn build_client_config( for cert in extra_certificates { roots - .add(cert.clone()) + .add(tls_utils::pemfile::rootstore_ref_shim(cert)) .map_err(KeyExchangeError::Certificate)?; } Ok( - rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13]) + tls_utils::client_config_builder_with_protocol_versions(&[&TLS13]) .with_root_certificates(roots) .with_no_client_auth(), ) @@ -55,7 +54,7 @@ async fn build_client_config( pub(crate) async fn key_exchange_client( server_name: String, port: u16, - extra_certificates: &[CertificateDer<'_>], + extra_certificates: &[Certificate], ntp_version: Option, ) -> Result { let socket = tokio::net::TcpStream::connect((server_name.as_str(), port)).await?; @@ -68,7 +67,7 @@ pub(crate) async fn key_exchange_client( pub(crate) async fn key_exchange_client_with_denied_servers( server_name: String, port: u16, - extra_certificates: &[CertificateDer<'_>], + extra_certificates: &[Certificate], ntp_version: Option, denied_servers: impl IntoIterator, ) -> Result { @@ -122,12 +121,12 @@ async fn run_nts_ke( )) })?; - let cert_chain: Vec = - rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_chain_file)) - .collect::>>()?; + let cert_chain: Vec = + ntp_proto::tls_utils::pemfile::certs(&mut std::io::BufReader::new(certificate_chain_file)) + .collect::>>()?; #[cfg_attr(not(feature = "unstable_nts-pool"), allow(unused_mut))] - let mut pool_certs: Vec = Vec::new(); + let mut pool_certs: Vec = Vec::new(); #[cfg(feature = "unstable_nts-pool")] for client_cert in &nts_ke_config.authorized_pool_server_certificates { let pool_certificate_file = std::fs::File::open(client_cert).map_err(|e| { @@ -136,9 +135,10 @@ async fn run_nts_ke( client_cert, e )) })?; - let mut certs: Vec<_> = - rustls_pemfile::certs(&mut std::io::BufReader::new(pool_certificate_file)) - .collect::>>()?; + let mut certs: Vec<_> = ntp_proto::tls_utils::pemfile::certs(&mut std::io::BufReader::new( + pool_certificate_file, + )) + .collect::>>()?; // forbid certificate chains at this point if certs.len() == 1 { pool_certs.push(certs.pop().unwrap()) @@ -150,25 +150,26 @@ async fn run_nts_ke( } } - let private_key = rustls_pemfile::private_key(&mut std::io::BufReader::new(private_key_file))? - .ok_or(io_error("could not parse private key"))?; + let private_key = + ntp_proto::tls_utils::pemfile::private_key(&mut std::io::BufReader::new(private_key_file))? + .ok_or(io_error("could not parse private key"))?; key_exchange_server(keyset, nts_ke_config, cert_chain, pool_certs, private_key).await } fn build_server_config( - certificate_chain: Vec>, - private_key: PrivateKeyDer<'static>, -) -> std::io::Result> { - let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13]) + certificate_chain: Vec, + private_key: PrivateKey, +) -> std::io::Result> { + let mut config = tls_utils::server_config_builder_with_protocol_versions(&[&TLS13]) .with_client_cert_verifier(Arc::new( #[cfg(not(feature = "unstable_nts-pool"))] - rustls::server::NoClientAuth, + tls_utils::NoClientAuth, #[cfg(feature = "unstable_nts-pool")] ntp_proto::tls_utils::AllowAnyAnonymousOrCertificateBearingClient::new( // We know that our previous call to ServerConfig::builder already // installed a default provider, but this is undocumented - rustls::crypto::CryptoProvider::get_default().unwrap(), + rustls23::crypto::CryptoProvider::get_default().unwrap(), ), )) .with_single_cert(certificate_chain, private_key) @@ -183,9 +184,9 @@ fn build_server_config( async fn key_exchange_server( keyset: tokio::sync::watch::Receiver>, ke_config: NtsKeConfig, - certificate_chain: Vec>, - pool_certs: Vec>, - private_key: PrivateKeyDer<'static>, + certificate_chain: Vec, + pool_certs: Vec, + private_key: PrivateKey, ) -> std::io::Result<()> { let config = build_server_config(certificate_chain, private_key)?; let pool_certs = Arc::<[_]>::from(pool_certs); @@ -284,7 +285,7 @@ where pub fn new( io: IO, server_name: String, - config: rustls::ClientConfig, + config: tls_utils::ClientConfig, ntp_version: Option, denied_servers: impl IntoIterator, ) -> Result { @@ -414,11 +415,11 @@ where { pub fn new( io: IO, - config: Arc, + config: Arc, keyset: Arc, ntp_port: Option, ntp_server: Option, - pool_certs: Arc<[rustls::pki_types::CertificateDer<'static>]>, + pool_certs: Arc<[Certificate]>, ) -> Result { let data = BoundKeyExchangeServerData { io, @@ -431,11 +432,11 @@ where pub async fn run( io: IO, - config: Arc, + config: Arc, keyset: Arc, ntp_port: Option, ntp_server: Option, - pool_certs: Arc<[rustls::pki_types::CertificateDer<'static>]>, + pool_certs: Arc<[Certificate]>, ) -> Result<(), KeyExchangeError> { let this = Self::new(io, config, keyset, ntp_port, ntp_server, pool_certs)?; @@ -601,24 +602,24 @@ impl Read for ReaderAdapter<'_, '_, T> { } } -pub(crate) fn certificates_from_file(path: &Path) -> std::io::Result>> { +pub(crate) fn certificates_from_file(path: &Path) -> std::io::Result> { let file = std::fs::File::open(path)?; let reader = BufReader::new(file); certificates_from_bufread(reader) } -fn certificates_from_bufread( - mut reader: impl BufRead, -) -> std::io::Result>> { - rustls_pemfile::certs(&mut reader).collect() +fn certificates_from_bufread(mut reader: impl BufRead) -> std::io::Result> { + ntp_proto::tls_utils::pemfile::certs(&mut reader).collect() } #[cfg(test)] mod tests { use std::{io::Cursor, net::SocketAddr, path::PathBuf}; + #[allow(unused)] use ntp_proto::{KeySetProvider, NtsRecord}; + #[allow(unused)] use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::test::alloc_port; @@ -644,31 +645,31 @@ mod tests { #[test] fn parse_private_keys() { let input = include_bytes!("../../test-keys/end.key"); - let _ = rustls_pemfile::private_key(&mut input.as_slice()) + let _ = ntp_proto::tls_utils::pemfile::private_key(&mut input.as_slice()) .unwrap() .unwrap(); let input = include_bytes!("../../test-keys/testca.key"); - let _ = rustls_pemfile::private_key(&mut input.as_slice()) + let _ = ntp_proto::tls_utils::pemfile::private_key(&mut input.as_slice()) .unwrap() .unwrap(); // openssl does no longer seem to want to generate this format // so we use https://github.com/rustls/pemfile/blob/main/tests/data/rsa1024.pkcs1.pem let input = include_bytes!("../../test-keys/rsa_key.pem"); - let _ = rustls_pemfile::private_key(&mut input.as_slice()) + let _ = ntp_proto::tls_utils::pemfile::private_key(&mut input.as_slice()) .unwrap() .unwrap(); // openssl ecparam -name prime256v1 -genkey -noout -out ec_key.pem let input = include_bytes!("../../test-keys/ec_key.pem"); - let _ = rustls_pemfile::private_key(&mut input.as_slice()) + let _ = ntp_proto::tls_utils::pemfile::private_key(&mut input.as_slice()) .unwrap() .unwrap(); // openssl genpkey -algorithm EC -out pkcs8_key.pem -pkeyopt ec_paramgen_curve:prime256v1 let input = include_bytes!("../../test-keys/pkcs8_key.pem"); - let _ = rustls_pemfile::private_key(&mut input.as_slice()) + let _ = ntp_proto::tls_utils::pemfile::private_key(&mut input.as_slice()) .unwrap() .unwrap(); } @@ -712,6 +713,7 @@ mod tests { assert_eq!(result.port, 123); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn key_exchange_weird_packet() { let port = alloc_port(); @@ -749,6 +751,7 @@ mod tests { assert_eq!(len, 880); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn key_exchange_bad_request() { let port = alloc_port(); @@ -950,6 +953,7 @@ mod tests { } } + #[cfg(feature = "run_tokio_rustls_tests")] fn client_key_exchange_message_length() -> usize { let mut buffer = Vec::with_capacity(1024); for record in ntp_proto::NtsRecord::client_key_exchange_records(None, vec![]).iter() { @@ -959,6 +963,7 @@ mod tests { buffer.len() } + #[cfg(feature = "run_tokio_rustls_tests")] async fn send_records_to_client( records: Vec, ) -> Result { @@ -973,7 +978,7 @@ mod tests { certificates_from_bufread(BufReader::new(Cursor::new(cc))).unwrap(); let pk = include_bytes!("../../test-keys/end.key"); - let private_key = rustls_pemfile::private_key(&mut pk.as_slice()) + let private_key = ntp_proto::tls_utils::pemfile::private_key(&mut pk.as_slice()) .unwrap() .unwrap(); @@ -1003,12 +1008,13 @@ mod tests { key_exchange_client("localhost".to_string(), port, extra_certificates, None).await } + #[cfg(feature = "run_tokio_rustls_tests")] async fn run_server(listener: tokio::net::TcpListener) -> Result<(), KeyExchangeError> { let cc = include_bytes!("../../test-keys/end.fullchain.pem"); let certificate_chain = certificates_from_bufread(BufReader::new(Cursor::new(cc)))?; let pk = include_bytes!("../../test-keys/end.key"); - let private_key = rustls_pemfile::private_key(&mut pk.as_slice()) + let private_key = ntp_proto::tls_utils::pemfile::private_key(&mut pk.as_slice()) .unwrap() .unwrap(); @@ -1023,6 +1029,7 @@ mod tests { BoundKeyExchangeServer::run(stream, config, keyset, None, None, pool_certs).await } + #[cfg(feature = "run_tokio_rustls_tests")] async fn client_tls_stream( server_name: &str, port: u16, @@ -1037,7 +1044,7 @@ mod tests { let config = build_client_config(extra_certificates).await.unwrap(); - let domain = rustls::pki_types::ServerName::try_from(server_name) + let domain = tls_utils::ServerName::try_from(server_name) .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname")) .unwrap() .to_owned(); @@ -1046,6 +1053,7 @@ mod tests { connector.connect(domain, stream).await.unwrap() } + #[cfg(feature = "run_tokio_rustls_tests")] async fn send_records_to_server(records: Vec) -> Result<(), KeyExchangeError> { let port = alloc_port(); let listener = TcpListener::bind(&("localhost", port)).await?; @@ -1072,6 +1080,7 @@ mod tests { run_server(listener).await } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn receive_cookies() { let result = send_records_to_client(vec![ @@ -1092,6 +1101,7 @@ mod tests { assert!(result.is_ok()); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn records_after_end_are_ignored() { let result = send_records_to_client(vec![ @@ -1115,6 +1125,7 @@ mod tests { assert!(result.is_ok()); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn no_cookies() { let result = send_records_to_client(vec![ @@ -1134,6 +1145,7 @@ mod tests { assert!(matches!(error, KeyExchangeError::NoCookies)); } + #[cfg(feature = "run_tokio_rustls_tests")] async fn client_error_record(errorcode: u16) -> KeyExchangeError { let result = send_records_to_client(vec![ NtsRecord::Error { errorcode }, @@ -1144,6 +1156,7 @@ mod tests { result.unwrap_err() } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn client_receives_error_record() { use KeyExchangeError as KEE; @@ -1158,6 +1171,7 @@ mod tests { assert!(matches!(error, KEE::InternalServerError)); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn server_expected_client_records() { let records = NtsRecord::client_key_exchange_records(None, vec![]).to_vec(); @@ -1166,6 +1180,7 @@ mod tests { assert!(result.is_ok()); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn immediate_end_of_message() { let records = vec![NtsRecord::EndOfMessage]; @@ -1174,6 +1189,7 @@ mod tests { assert!(matches!(result, Err(KeyExchangeError::NoValidProtocol))); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn double_next_protocol() { let records = vec![ @@ -1190,6 +1206,7 @@ mod tests { assert!(matches!(result, Err(KeyExchangeError::BadRequest))); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn records_after_end_of_message() { let records = vec![ @@ -1210,6 +1227,7 @@ mod tests { assert!(result.is_ok()); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn client_no_valid_algorithm() { let records = vec![ @@ -1227,6 +1245,7 @@ mod tests { assert!(matches!(result, Err(KeyExchangeError::NoValidAlgorithm))); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn client_no_valid_protocol() { let records = vec![ @@ -1244,6 +1263,7 @@ mod tests { assert!(matches!(result, Err(KeyExchangeError::NoValidProtocol))); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn unrecognized_critical_record() { let records = vec![ @@ -1262,6 +1282,7 @@ mod tests { )); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn client_sends_no_records_clean_shutdown() { let port = alloc_port(); @@ -1280,6 +1301,7 @@ mod tests { assert!(matches!(result, Err(KeyExchangeError::IncompleteResponse))); } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] #[ignore = "Flaky on macos and not that interesting"] async fn client_sends_no_records_dirty_shutdown() { @@ -1299,6 +1321,7 @@ mod tests { assert!(matches!(result, Err(KeyExchangeError::IncompleteResponse))); } + #[cfg(feature = "run_tokio_rustls_tests")] async fn server_error_record(errorcode: u16) -> KeyExchangeError { let result = send_records_to_server(vec![ NtsRecord::Error { errorcode }, @@ -1309,6 +1332,7 @@ mod tests { result.unwrap_err() } + #[cfg(feature = "run_tokio_rustls_tests")] #[tokio::test] async fn server_receives_error_record() { use KeyExchangeError as KEE; diff --git a/nts-pool-ke/Cargo.toml b/nts-pool-ke/Cargo.toml index 6b8de61ff..090a69f7e 100644 --- a/nts-pool-ke/Cargo.toml +++ b/nts-pool-ke/Cargo.toml @@ -10,18 +10,25 @@ description.workspace = true publish.workspace = true rust-version.workspace = true +[package.metadata.cargo-udeps.ignore] +normal = [ "ntp-proto", "rustls-native-certs7", "rustls-pemfile2", "rustls23", "serde", "tokio-rustls", "toml", "tracing", "tracing-subscriber" ] + [dependencies] tokio = { workspace = true, features = ["rt-multi-thread", "io-util", "fs", "net", "macros", "time" ] } toml.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3.0", default-features = false, features = ["std", "fmt", "ansi"] } -rustls.workspace = true -rustls-pemfile.workspace = true -rustls-native-certs.workspace = true +rustls23.workspace = true +rustls-pemfile2.workspace = true +rustls-native-certs7.workspace = true serde.workspace = true -ntp-proto = { workspace = true, features = ["nts-pool"] } +ntp-proto = { workspace = true } tokio-rustls.workspace = true +[features] +default = [] +unstable_nts-pool = [ "ntp-proto/nts-pool" ] + [[bin]] name = "nts-pool-ke" path = "bin/nts-pool-ke.rs" diff --git a/nts-pool-ke/src/cli.rs b/nts-pool-ke/src/condcompile/cli.rs similarity index 99% rename from nts-pool-ke/src/cli.rs rename to nts-pool-ke/src/condcompile/cli.rs index c3260797b..7ab142643 100644 --- a/nts-pool-ke/src/cli.rs +++ b/nts-pool-ke/src/condcompile/cli.rs @@ -1,4 +1,4 @@ -use crate::daemon_tracing::LogLevel; +use super::daemon_tracing::LogLevel; use std::path::PathBuf; use std::str::FromStr; diff --git a/nts-pool-ke/src/config.rs b/nts-pool-ke/src/condcompile/config.rs similarity index 98% rename from nts-pool-ke/src/config.rs rename to nts-pool-ke/src/condcompile/config.rs index b8eb8034f..dea2168c5 100644 --- a/nts-pool-ke/src/config.rs +++ b/nts-pool-ke/src/condcompile/config.rs @@ -77,7 +77,7 @@ impl Config { #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub struct ObservabilityConfig { #[serde(default)] - pub log_level: Option, + pub log_level: Option, } #[derive(Debug, PartialEq, Eq, Clone, Deserialize)] diff --git a/nts-pool-ke/src/tracing.rs b/nts-pool-ke/src/condcompile/tracing.rs similarity index 100% rename from nts-pool-ke/src/tracing.rs rename to nts-pool-ke/src/condcompile/tracing.rs diff --git a/nts-pool-ke/src/lib.rs b/nts-pool-ke/src/lib.rs index 02d7a025e..f52ff57d0 100644 --- a/nts-pool-ke/src/lib.rs +++ b/nts-pool-ke/src/lib.rs @@ -1,565 +1,582 @@ -mod cli; -mod config; - -mod tracing; - -use std::{io::ErrorKind, ops::ControlFlow, path::PathBuf, sync::Arc}; - -use ::tracing::{info, warn}; -use cli::NtsPoolKeOptions; -use config::{Config, NtsPoolKeConfig}; -use ntp_proto::{ - AeadAlgorithm, ClientToPoolData, KeyExchangeError, NtsRecord, PoolToServerData, - PoolToServerDecoder, SupportedAlgorithmsDecoder, -}; -use rustls::{ - pki_types::{CertificateDer, ServerName}, - version::TLS13, -}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpListener, ToSocketAddrs}, -}; -use tokio_rustls::TlsConnector; - -use crate::tracing as daemon_tracing; -use daemon_tracing::LogLevel; -use tracing_subscriber::util::SubscriberInitExt; - -pub(crate) mod exitcode { - /// An internal software error has been detected. This - /// should be limited to non-operating system related - /// errors as possible. - pub const SOFTWARE: i32 = 70; - - /// Something was found in an unconfigured or misconfigured state. - pub const CONFIG: i32 = 78; -} +#[cfg(feature = "unstable_nts-pool")] +mod condcompile { + extern crate rustls23 as rustls; + extern crate rustls_native_certs7 as rustls_native_certs; + extern crate rustls_pemfile2 as rustls_pemfile; -const VERSION: &str = env!("CARGO_PKG_VERSION"); + mod cli; + mod config; -pub async fn nts_pool_ke_main() -> Result<(), Box> { - let options = NtsPoolKeOptions::try_parse_from(std::env::args())?; + mod tracing; - match options.action { - cli::NtsPoolKeAction::Help => { - println!("{}", cli::long_help_message()); - } - cli::NtsPoolKeAction::Version => { - eprintln!("nts-pool-ke {VERSION}"); - } - cli::NtsPoolKeAction::Run => run(options).await?, + use std::{io::ErrorKind, ops::ControlFlow, path::PathBuf, sync::Arc}; + + use ::tracing::{info, warn}; + use cli::NtsPoolKeOptions; + use config::{Config, NtsPoolKeConfig}; + use ntp_proto::{ + AeadAlgorithm, ClientToPoolData, KeyExchangeError, NtsRecord, PoolToServerData, + PoolToServerDecoder, SupportedAlgorithmsDecoder, + }; + use rustls::{ + pki_types::{CertificateDer, ServerName}, + version::TLS13, + }; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, ToSocketAddrs}, + }; + use tokio_rustls::TlsConnector; + + use self::tracing as daemon_tracing; + use daemon_tracing::LogLevel; + use tracing_subscriber::util::SubscriberInitExt; + + pub(crate) mod exitcode { + /// An internal software error has been detected. This + /// should be limited to non-operating system related + /// errors as possible. + pub const SOFTWARE: i32 = 70; + + /// Something was found in an unconfigured or misconfigured state. + pub const CONFIG: i32 = 78; } - Ok(()) -} + const VERSION: &str = env!("CARGO_PKG_VERSION"); -// initializes the logger so that logs during config parsing are reported. Then it overrides the -// log level based on the config if required. -pub(crate) async fn initialize_logging_parse_config( - initial_log_level: Option, - config_path: Option, -) -> Config { - let mut log_level = initial_log_level.unwrap_or_default(); - - let config_tracing = daemon_tracing::tracing_init(log_level); - let config = ::tracing::subscriber::with_default(config_tracing, || { - async { - match config_path { - None => { - eprintln!("no configuration path specified"); - std::process::exit(exitcode::CONFIG); - } - Some(config_path) => { - match Config::from_args(config_path).await { - Ok(c) => c, - Err(e) => { - // print to stderr because tracing is not yet setup - eprintln!("There was an error loading the config: {e}"); - std::process::exit(exitcode::CONFIG); + pub async fn nts_pool_ke_main() -> Result<(), Box> { + let options = NtsPoolKeOptions::try_parse_from(std::env::args())?; + + match options.action { + cli::NtsPoolKeAction::Help => { + println!("{}", cli::long_help_message()); + } + cli::NtsPoolKeAction::Version => { + eprintln!("nts-pool-ke {VERSION}"); + } + cli::NtsPoolKeAction::Run => run(options).await?, + } + + Ok(()) + } + + // initializes the logger so that logs during config parsing are reported. Then it overrides the + // log level based on the config if required. + pub(crate) async fn initialize_logging_parse_config( + initial_log_level: Option, + config_path: Option, + ) -> Config { + let mut log_level = initial_log_level.unwrap_or_default(); + + let config_tracing = daemon_tracing::tracing_init(log_level); + let config = ::tracing::subscriber::with_default(config_tracing, || { + async { + match config_path { + None => { + eprintln!("no configuration path specified"); + std::process::exit(exitcode::CONFIG); + } + Some(config_path) => { + match Config::from_args(config_path).await { + Ok(c) => c, + Err(e) => { + // print to stderr because tracing is not yet setup + eprintln!("There was an error loading the config: {e}"); + std::process::exit(exitcode::CONFIG); + } } } } } - } - }) - .await; + }) + .await; - if let Some(config_log_level) = config.observability.log_level { - if initial_log_level.is_none() { - log_level = config_log_level; + if let Some(config_log_level) = config.observability.log_level { + if initial_log_level.is_none() { + log_level = config_log_level; + } } - } - // set a default global subscriber from now on - let tracing_inst = daemon_tracing::tracing_init(log_level); - tracing_inst.init(); + // set a default global subscriber from now on + let tracing_inst = daemon_tracing::tracing_init(log_level); + tracing_inst.init(); - config -} + config + } -async fn run(options: NtsPoolKeOptions) -> Result<(), Box> { - let config = initialize_logging_parse_config(options.log_level, options.config).await; + async fn run(options: NtsPoolKeOptions) -> Result<(), Box> { + let config = initialize_logging_parse_config(options.log_level, options.config).await; - // give the user a warning that we use the command line option - if config.observability.log_level.is_some() && options.log_level.is_some() { - ::tracing::info!("Log level override from command line arguments is active"); - } + // give the user a warning that we use the command line option + if config.observability.log_level.is_some() && options.log_level.is_some() { + ::tracing::info!("Log level override from command line arguments is active"); + } - // Warn/error if the config is unreasonable. We do this after finishing - // tracing setup to ensure logging is fully configured. - config.check(); + // Warn/error if the config is unreasonable. We do this after finishing + // tracing setup to ensure logging is fully configured. + config.check(); - let result = run_nts_pool_ke(config.nts_pool_ke_server).await; + let result = run_nts_pool_ke(config.nts_pool_ke_server).await; - match result { - Ok(v) => Ok(v), - Err(e) => { - ::tracing::error!("Abnormal termination of NTS KE server: {e}"); - std::process::exit(exitcode::SOFTWARE) + match result { + Ok(v) => Ok(v), + Err(e) => { + ::tracing::error!("Abnormal termination of NTS KE server: {e}"); + std::process::exit(exitcode::SOFTWARE) + } } } -} - -async fn run_nts_pool_ke(nts_pool_ke_config: NtsPoolKeConfig) -> std::io::Result<()> { - let certificate_authority_file = - std::fs::File::open(&nts_pool_ke_config.certificate_authority_path).map_err(|e| { - io_error(&format!( - "error reading certificate_authority_path at `{:?}`: {:?}", - nts_pool_ke_config.certificate_authority_path, e - )) - })?; - - let certificate_chain_file = std::fs::File::open(&nts_pool_ke_config.certificate_chain_path) - .map_err(|e| { - io_error(&format!( - "error reading certificate_chain_path at `{:?}`: {:?}", - nts_pool_ke_config.certificate_chain_path, e - )) - })?; - - let private_key_file = - std::fs::File::open(&nts_pool_ke_config.private_key_path).map_err(|e| { - io_error(&format!( - "error reading key_der_path at `{:?}`: {:?}", - nts_pool_ke_config.private_key_path, e - )) - })?; - - let certificate_authority: Arc<[rustls::pki_types::CertificateDer]> = - rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_authority_file)) - .collect::>>()?; - - let certificate_chain: Vec = - rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_chain_file)) - .collect::>>()?; - - let private_key = rustls_pemfile::private_key(&mut std::io::BufReader::new(private_key_file))? - .ok_or(io_error("could not parse private key"))?; - - pool_key_exchange_server( - nts_pool_ke_config.listen, - certificate_authority, - certificate_chain, - private_key, - nts_pool_ke_config.key_exchange_servers, - nts_pool_ke_config.key_exchange_timeout_ms, - ) - .await -} - -fn io_error(msg: &str) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::Other, msg) -} -async fn pool_key_exchange_server( - address: impl ToSocketAddrs, - certificate_authority: Arc<[rustls::pki_types::CertificateDer<'static>]>, - certificate_chain: Vec>, - private_key: rustls::pki_types::PrivateKeyDer<'static>, - servers: Vec, - timeout_ms: u64, -) -> std::io::Result<()> { - let listener = TcpListener::bind(address).await?; - - let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13]) - .with_no_client_auth() - .with_single_cert(certificate_chain.clone(), private_key.clone_key()) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?; - - config.alpn_protocols.clear(); - config.alpn_protocols.push(b"ntske/1".to_vec()); - - let config = Arc::new(config); - let servers: Arc<[_]> = servers.into(); - - info!("listening on '{:?}'", listener.local_addr()); - - loop { - let (client_stream, source_address) = listener.accept().await?; - let client_to_pool_config = config.clone(); - let servers = servers.clone(); - let certificate_chain = certificate_chain.clone(); - let private_key = private_key.clone_key(); - - let certificate_authority = certificate_authority.clone(); - let fut = handle_client( - client_stream, - client_to_pool_config, + async fn run_nts_pool_ke(nts_pool_ke_config: NtsPoolKeConfig) -> std::io::Result<()> { + let certificate_authority_file = + std::fs::File::open(&nts_pool_ke_config.certificate_authority_path).map_err(|e| { + io_error(&format!( + "error reading certificate_authority_path at `{:?}`: {:?}", + nts_pool_ke_config.certificate_authority_path, e + )) + })?; + + let certificate_chain_file = + std::fs::File::open(&nts_pool_ke_config.certificate_chain_path).map_err(|e| { + io_error(&format!( + "error reading certificate_chain_path at `{:?}`: {:?}", + nts_pool_ke_config.certificate_chain_path, e + )) + })?; + + let private_key_file = + std::fs::File::open(&nts_pool_ke_config.private_key_path).map_err(|e| { + io_error(&format!( + "error reading key_der_path at `{:?}`: {:?}", + nts_pool_ke_config.private_key_path, e + )) + })?; + + let certificate_authority: Arc<[rustls::pki_types::CertificateDer]> = + rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_authority_file)) + .collect::>>()?; + + let certificate_chain: Vec = + rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_chain_file)) + .collect::>>()?; + + let private_key = + rustls_pemfile::private_key(&mut std::io::BufReader::new(private_key_file))? + .ok_or(io_error("could not parse private key"))?; + + pool_key_exchange_server( + nts_pool_ke_config.listen, certificate_authority, certificate_chain, private_key, - servers, - ); - - tokio::spawn(async move { - let timeout = std::time::Duration::from_millis(timeout_ms); - match tokio::time::timeout(timeout, fut).await { - Err(_) => ::tracing::debug!(?source_address, "NTS Pool KE timed out"), - Ok(Err(err)) => ::tracing::debug!(?err, ?source_address, "NTS Pool KE failed"), - Ok(Ok(())) => ::tracing::debug!(?source_address, "NTS Pool KE completed"), - } - }); + nts_pool_ke_config.key_exchange_servers, + nts_pool_ke_config.key_exchange_timeout_ms, + ) + .await } -} -async fn try_nts_ke_server<'a>( - connector: &TlsConnector, - config::KeyExchangeServer { domain, port }: &'a config::KeyExchangeServer, - selected_algorithm: AeadAlgorithm, -) -> Result<(&'a str, u16, ServerName<'static>), KeyExchangeError> { - info!("checking supported algorithms for '{domain}:{port}'"); - - let domain = domain.as_str(); - let server_name = rustls::pki_types::ServerName::try_from(domain) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname"))? - .to_owned(); - - let server_stream = match tokio::net::TcpStream::connect((domain, *port)).await { - Ok(server_stream) => server_stream, - Err(e) => return Err(e.into()), - }; - let mut server_stream = connector - .connect(server_name.clone(), server_stream) - .await?; + fn io_error(msg: &str) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, msg) + } - info!("established connection to the server"); + async fn pool_key_exchange_server( + address: impl ToSocketAddrs, + certificate_authority: Arc<[rustls::pki_types::CertificateDer<'static>]>, + certificate_chain: Vec>, + private_key: rustls::pki_types::PrivateKeyDer<'static>, + servers: Vec, + timeout_ms: u64, + ) -> std::io::Result<()> { + let listener = TcpListener::bind(address).await?; + + let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13]) + .with_no_client_auth() + .with_single_cert(certificate_chain.clone(), private_key.clone_key()) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?; + + config.alpn_protocols.clear(); + config.alpn_protocols.push(b"ntske/1".to_vec()); + + let config = Arc::new(config); + let servers: Arc<[_]> = servers.into(); + + info!("listening on '{:?}'", listener.local_addr()); + + loop { + let (client_stream, source_address) = listener.accept().await?; + let client_to_pool_config = config.clone(); + let servers = servers.clone(); + let certificate_chain = certificate_chain.clone(); + let private_key = private_key.clone_key(); + + let certificate_authority = certificate_authority.clone(); + let fut = handle_client( + client_stream, + client_to_pool_config, + certificate_authority, + certificate_chain, + private_key, + servers, + ); + + tokio::spawn(async move { + let timeout = std::time::Duration::from_millis(timeout_ms); + match tokio::time::timeout(timeout, fut).await { + Err(_) => ::tracing::debug!(?source_address, "NTS Pool KE timed out"), + Ok(Err(err)) => ::tracing::debug!(?err, ?source_address, "NTS Pool KE failed"), + Ok(Ok(())) => ::tracing::debug!(?source_address, "NTS Pool KE completed"), + } + }); + } + } - let supported_algorithms = supported_algorithms_request(&mut server_stream).await?; + async fn try_nts_ke_server<'a>( + connector: &TlsConnector, + config::KeyExchangeServer { domain, port }: &'a config::KeyExchangeServer, + selected_algorithm: AeadAlgorithm, + ) -> Result<(&'a str, u16, ServerName<'static>), KeyExchangeError> { + info!("checking supported algorithms for '{domain}:{port}'"); + + let domain = domain.as_str(); + let server_name = rustls::pki_types::ServerName::try_from(domain) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname"))? + .to_owned(); + + let server_stream = match tokio::net::TcpStream::connect((domain, *port)).await { + Ok(server_stream) => server_stream, + Err(e) => return Err(e.into()), + }; + let mut server_stream = connector + .connect(server_name.clone(), server_stream) + .await?; - info!("received supported algorithms from the NTS KE server"); + info!("established connection to the server"); - if supported_algorithms - .iter() - .any(|(algorithm_id, _)| *algorithm_id == selected_algorithm as u16) - { - Ok((domain, *port, server_name)) - } else { - Err(KeyExchangeError::NoValidAlgorithm) - } -} + let supported_algorithms = supported_algorithms_request(&mut server_stream).await?; -async fn pick_nts_ke_servers<'a>( - connector: &TlsConnector, - servers: &'a [config::KeyExchangeServer], - selected_algorithm: AeadAlgorithm, - denied_servers: &[String], -) -> Result<(&'a str, u16, ServerName<'static>), KeyExchangeError> { - use std::sync::atomic::{AtomicUsize, Ordering}; - static START_INDEX: AtomicUsize = AtomicUsize::new(0); - let start_index = START_INDEX.fetch_add(1, Ordering::Relaxed); - - // rotate the serverlist so that an error caused by a single NTS-KE server doesn't - // permanently cripple the pool - let (left, right) = servers.split_at(start_index % servers.len()); - let rotated_servers = right.iter().chain(left.iter()); - let mut connection_error = false; - - for server in rotated_servers { - if denied_servers.contains(&server.domain) { - continue; - } + info!("received supported algorithms from the NTS KE server"); - match try_nts_ke_server(connector, server, selected_algorithm).await { - Ok(x) => return Ok(x), - Err(e) => match e { - // only if the connection was refused do we try another server during this - // connection, because otherwise key material from this TLS connection could have - // been shared with multiple servers through the FixedKeyRequest - KeyExchangeError::Io(e) if e.kind() == ErrorKind::ConnectionRefused => { - connection_error = true; - ::tracing::debug!("connection refused: {e}"); - continue; - } - _ => return Err(e), - }, + if supported_algorithms + .iter() + .any(|(algorithm_id, _)| *algorithm_id == selected_algorithm as u16) + { + Ok((domain, *port, server_name)) + } else { + Err(KeyExchangeError::NoValidAlgorithm) } } - warn!("pool could not find a valid KE server"); - - // if finding a KE server failed because of a connection error, - // report a Internal Server Error; if it failed because all our servers - // were able and willing, but rejected by the client, return a Bad Request - if connection_error { - Err(KeyExchangeError::InternalServerError) - } else { - Err(KeyExchangeError::BadRequest) - } -} + async fn pick_nts_ke_servers<'a>( + connector: &TlsConnector, + servers: &'a [config::KeyExchangeServer], + selected_algorithm: AeadAlgorithm, + denied_servers: &[String], + ) -> Result<(&'a str, u16, ServerName<'static>), KeyExchangeError> { + use std::sync::atomic::{AtomicUsize, Ordering}; + static START_INDEX: AtomicUsize = AtomicUsize::new(0); + let start_index = START_INDEX.fetch_add(1, Ordering::Relaxed); + + // rotate the serverlist so that an error caused by a single NTS-KE server doesn't + // permanently cripple the pool + let (left, right) = servers.split_at(start_index % servers.len()); + let rotated_servers = right.iter().chain(left.iter()); + let mut connection_error = false; + + for server in rotated_servers { + if denied_servers.contains(&server.domain) { + continue; + } -async fn handle_client( - client_stream: tokio::net::TcpStream, - config: Arc, - certificate_authority: Arc<[rustls::pki_types::CertificateDer<'static>]>, - certificate_chain: Vec>, - private_key: rustls::pki_types::PrivateKeyDer<'static>, - servers: Arc<[config::KeyExchangeServer]>, -) -> Result<(), KeyExchangeError> { - // handle the initial client to pool - let acceptor = tokio_rustls::TlsAcceptor::from(config); - let mut client_stream = acceptor.accept(client_stream).await?; - - // read all records from the client - let client_data = client_to_pool_request(&mut client_stream).await?; - - info!("received records from the client",); - - let connector = - pool_to_server_connector(&certificate_authority, certificate_chain, private_key)?; - - let pick = pick_nts_ke_servers( - &connector, - &servers, - client_data.algorithm, - &client_data.denied_servers, - ) - .await; - - let (server_name, port, domain) = match pick { - Ok(x) => x, - Err(e) => { - // for now, just send back to the client that its algorithms were invalid - // AeadAlgorithm::AeadAesSivCmac256 should always be supported by servers and clients - info!(?e, "could not find a valid KE server"); - - let records = [ - NtsRecord::NextProtocol { - protocol_ids: vec![0], - }, - NtsRecord::Error { - errorcode: e.to_error_code(), + match try_nts_ke_server(connector, server, selected_algorithm).await { + Ok(x) => return Ok(x), + Err(e) => match e { + // only if the connection was refused do we try another server during this + // connection, because otherwise key material from this TLS connection could have + // been shared with multiple servers through the FixedKeyRequest + KeyExchangeError::Io(e) if e.kind() == ErrorKind::ConnectionRefused => { + connection_error = true; + ::tracing::debug!("connection refused: {e}"); + continue; + } + _ => return Err(e), }, - NtsRecord::EndOfMessage, - ]; - - let mut buffer = Vec::with_capacity(1024); - for record in records { - record.write(&mut buffer)?; } + } - client_stream.write_all(&buffer).await?; - client_stream.shutdown().await?; + warn!("pool could not find a valid KE server"); - return Ok(()); + // if finding a KE server failed because of a connection error, + // report a Internal Server Error; if it failed because all our servers + // were able and willing, but rejected by the client, return a Bad Request + if connection_error { + Err(KeyExchangeError::InternalServerError) + } else { + Err(KeyExchangeError::BadRequest) } - }; + } - // this is inefficient of course, but spec-compliant: the TLS connection is closed when the server - // receives a EndOfMessage record, so we have to establish a new one. re-using the TCP - // connection runs into issues (seems to leave the server in an invalid state). - let server_stream = tokio::net::TcpStream::connect((server_name, port)).await?; - let server_stream = connector.connect(domain, server_stream).await?; + async fn handle_client( + client_stream: tokio::net::TcpStream, + config: Arc, + certificate_authority: Arc<[rustls::pki_types::CertificateDer<'static>]>, + certificate_chain: Vec>, + private_key: rustls::pki_types::PrivateKeyDer<'static>, + servers: Arc<[config::KeyExchangeServer]>, + ) -> Result<(), KeyExchangeError> { + // handle the initial client to pool + let acceptor = tokio_rustls::TlsAcceptor::from(config); + let mut client_stream = acceptor.accept(client_stream).await?; + + // read all records from the client + let client_data = client_to_pool_request(&mut client_stream).await?; + + info!("received records from the client",); + + let connector = + pool_to_server_connector(&certificate_authority, certificate_chain, private_key)?; + + let pick = pick_nts_ke_servers( + &connector, + &servers, + client_data.algorithm, + &client_data.denied_servers, + ) + .await; + + let (server_name, port, domain) = match pick { + Ok(x) => x, + Err(e) => { + // for now, just send back to the client that its algorithms were invalid + // AeadAlgorithm::AeadAesSivCmac256 should always be supported by servers and clients + info!(?e, "could not find a valid KE server"); + + let records = [ + NtsRecord::NextProtocol { + protocol_ids: vec![0], + }, + NtsRecord::Error { + errorcode: e.to_error_code(), + }, + NtsRecord::EndOfMessage, + ]; + + let mut buffer = Vec::with_capacity(1024); + for record in records { + record.write(&mut buffer)?; + } - info!("fetching cookies from the NTS KE server"); + client_stream.write_all(&buffer).await?; + client_stream.shutdown().await?; - // get the cookies from the NTS KE server - let records_for_server = prepare_records_for_server(&client_stream, client_data)?; - match cookie_request(server_stream, &records_for_server).await { - Err(e) => { - warn!(?e, "NTS KE server returned an error"); + return Ok(()); + } + }; - Err(e) - } - Ok(records_for_client) => { - info!("received cookies from the NTS KE server"); + // this is inefficient of course, but spec-compliant: the TLS connection is closed when the server + // receives a EndOfMessage record, so we have to establish a new one. re-using the TCP + // connection runs into issues (seems to leave the server in an invalid state). + let server_stream = tokio::net::TcpStream::connect((server_name, port)).await?; + let server_stream = connector.connect(domain, server_stream).await?; - // now we just forward the response - let mut buffer = Vec::with_capacity(1024); - let (mut mentions_server, mut mentions_port) = (false, false); + info!("fetching cookies from the NTS KE server"); - for record in &records_for_client { - mentions_server |= matches!(record, NtsRecord::Server { .. }); - mentions_port |= matches!(record, NtsRecord::Port { .. }); + // get the cookies from the NTS KE server + let records_for_server = prepare_records_for_server(&client_stream, client_data)?; + match cookie_request(server_stream, &records_for_server).await { + Err(e) => { + warn!(?e, "NTS KE server returned an error"); + + Err(e) } + Ok(records_for_client) => { + info!("received cookies from the NTS KE server"); + + // now we just forward the response + let mut buffer = Vec::with_capacity(1024); + let (mut mentions_server, mut mentions_port) = (false, false); - if !mentions_server { - NtsRecord::Server { - critical: true, - name: server_name.to_string(), + for record in &records_for_client { + mentions_server |= matches!(record, NtsRecord::Server { .. }); + mentions_port |= matches!(record, NtsRecord::Port { .. }); } - .write(&mut buffer)?; - } - const NTP_DEFAULT_PORT: u16 = 123; - if !mentions_port { - NtsRecord::Port { - critical: true, - port: NTP_DEFAULT_PORT, + if !mentions_server { + NtsRecord::Server { + critical: true, + name: server_name.to_string(), + } + .write(&mut buffer)?; } - .write(&mut buffer)?; - } - for record in records_for_client { - record.write(&mut buffer)?; - } + const NTP_DEFAULT_PORT: u16 = 123; + if !mentions_port { + NtsRecord::Port { + critical: true, + port: NTP_DEFAULT_PORT, + } + .write(&mut buffer)?; + } + + for record in records_for_client { + record.write(&mut buffer)?; + } - client_stream.write_all(&buffer).await?; - client_stream.shutdown().await?; + client_stream.write_all(&buffer).await?; + client_stream.shutdown().await?; - info!("wrote records for client"); + info!("wrote records for client"); - Ok(()) + Ok(()) + } } } -} -fn prepare_records_for_server( - client_stream: &tokio_rustls::server::TlsStream, - client_data: ClientToPoolData, -) -> Result, KeyExchangeError> { - let nts_keys = client_data.extract_nts_keys(client_stream.get_ref().1)?; - - let mut records_for_server = client_data.records; - records_for_server.extend([ - NtsRecord::NextProtocol { - protocol_ids: vec![0], - }, - NtsRecord::AeadAlgorithm { - critical: false, - algorithm_ids: vec![client_data.algorithm as u16], - }, - nts_keys.as_fixed_key_request(), - NtsRecord::EndOfMessage, - ]); - - Ok(records_for_server) -} + fn prepare_records_for_server( + client_stream: &tokio_rustls::server::TlsStream, + client_data: ClientToPoolData, + ) -> Result, KeyExchangeError> { + let nts_keys = client_data.extract_nts_keys(client_stream.get_ref().1)?; -fn pool_to_server_connector( - extra_certificates: &[CertificateDer<'static>], - certificate_chain: Vec>, - private_key: rustls::pki_types::PrivateKeyDer<'static>, -) -> Result { - let mut roots = rustls::RootCertStore::empty(); - for cert in rustls_native_certs::load_native_certs()? { - roots.add(cert).map_err(KeyExchangeError::Certificate)?; - } + let mut records_for_server = client_data.records; + records_for_server.extend([ + NtsRecord::NextProtocol { + protocol_ids: vec![0], + }, + NtsRecord::AeadAlgorithm { + critical: false, + algorithm_ids: vec![client_data.algorithm as u16], + }, + nts_keys.as_fixed_key_request(), + NtsRecord::EndOfMessage, + ]); - for cert in extra_certificates { - roots - .add(cert.clone()) - .map_err(KeyExchangeError::Certificate)?; + Ok(records_for_server) } - let config = rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13]) - .with_root_certificates(roots) - .with_client_auth_cert(certificate_chain, private_key) - .unwrap(); + fn pool_to_server_connector( + extra_certificates: &[CertificateDer<'static>], + certificate_chain: Vec>, + private_key: rustls::pki_types::PrivateKeyDer<'static>, + ) -> Result { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs()? { + roots.add(cert).map_err(KeyExchangeError::Certificate)?; + } - // already has the FixedKeyRequest record - Ok(tokio_rustls::TlsConnector::from(Arc::new(config))) -} + for cert in extra_certificates { + roots + .add(cert.clone()) + .map_err(KeyExchangeError::Certificate)?; + } -async fn client_to_pool_request( - stream: &mut tokio_rustls::server::TlsStream, -) -> Result { - let mut decoder = ntp_proto::ClientToPoolDecoder::default(); + let config = rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13]) + .with_root_certificates(roots) + .with_client_auth_cert(certificate_chain, private_key) + .unwrap(); - let mut buf = [0; 1024]; + // already has the FixedKeyRequest record + Ok(tokio_rustls::TlsConnector::from(Arc::new(config))) + } - loop { - let n = stream.read(&mut buf).await?; + async fn client_to_pool_request( + stream: &mut tokio_rustls::server::TlsStream, + ) -> Result { + let mut decoder = ntp_proto::ClientToPoolDecoder::default(); - if n == 0 { - break Err(KeyExchangeError::IncompleteResponse); - } + let mut buf = [0; 1024]; - decoder = match decoder.step_with_slice(&buf[..n]) { - ControlFlow::Continue(decoder) => decoder, - ControlFlow::Break(done) => break done, - }; - } -} + loop { + let n = stream.read(&mut buf).await?; -async fn cookie_request( - mut stream: tokio_rustls::client::TlsStream, - nts_records: &[NtsRecord], -) -> Result, KeyExchangeError> { - // now we just forward the response - let mut buf = Vec::with_capacity(1024); - for record in nts_records { - record.write(&mut buf)?; + if n == 0 { + break Err(KeyExchangeError::IncompleteResponse); + } + + decoder = match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => decoder, + ControlFlow::Break(done) => break done, + }; + } } - stream.write_all(&buf).await?; + async fn cookie_request( + mut stream: tokio_rustls::client::TlsStream, + nts_records: &[NtsRecord], + ) -> Result, KeyExchangeError> { + // now we just forward the response + let mut buf = Vec::with_capacity(1024); + for record in nts_records { + record.write(&mut buf)?; + } - let mut buf = [0; 1024]; - let mut decoder = PoolToServerDecoder::default(); + stream.write_all(&buf).await?; - loop { - let n = stream.read(&mut buf).await?; + let mut buf = [0; 1024]; + let mut decoder = PoolToServerDecoder::default(); - if n == 0 { - break Err(KeyExchangeError::IncompleteResponse); - } + loop { + let n = stream.read(&mut buf).await?; - decoder = match decoder.step_with_slice(&buf[..n]) { - ControlFlow::Continue(decoder) => decoder, - ControlFlow::Break(Ok(PoolToServerData { - records, - algorithm: _, - protocol: _, - })) => { - stream.shutdown().await?; - break Ok(records); + if n == 0 { + break Err(KeyExchangeError::IncompleteResponse); } - ControlFlow::Break(Err(error)) => break Err(error), - }; - } -} -async fn supported_algorithms_request( - stream: &mut tokio_rustls::client::TlsStream, -) -> Result, KeyExchangeError> { - let nts_records = [ - NtsRecord::SupportedAlgorithmList { - supported_algorithms: vec![], - }, - NtsRecord::EndOfMessage, - ]; - - // now we just forward the response - let mut buf = Vec::with_capacity(1024); - for record in nts_records { - record.write(&mut buf)?; + decoder = match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => decoder, + ControlFlow::Break(Ok(PoolToServerData { + records, + algorithm: _, + protocol: _, + })) => { + stream.shutdown().await?; + break Ok(records); + } + ControlFlow::Break(Err(error)) => break Err(error), + }; + } } - stream.write_all(&buf).await?; + async fn supported_algorithms_request( + stream: &mut tokio_rustls::client::TlsStream, + ) -> Result, KeyExchangeError> { + let nts_records = [ + NtsRecord::SupportedAlgorithmList { + supported_algorithms: vec![], + }, + NtsRecord::EndOfMessage, + ]; - let mut buf = [0; 1024]; - let mut decoder = SupportedAlgorithmsDecoder::default(); + // now we just forward the response + let mut buf = Vec::with_capacity(1024); + for record in nts_records { + record.write(&mut buf)?; + } - loop { - let n = stream.read(&mut buf).await?; + stream.write_all(&buf).await?; - if n == 0 { - break Err(KeyExchangeError::IncompleteResponse); - } + let mut buf = [0; 1024]; + let mut decoder = SupportedAlgorithmsDecoder::default(); - decoder = match decoder.step_with_slice(&buf[..n]) { - ControlFlow::Continue(decoder) => decoder, - ControlFlow::Break(result) => break result, - }; + loop { + let n = stream.read(&mut buf).await?; + + if n == 0 { + break Err(KeyExchangeError::IncompleteResponse); + } + + decoder = match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => decoder, + ControlFlow::Break(result) => break result, + }; + } } } + +#[cfg(feature = "unstable_nts-pool")] +pub use condcompile::*; + +#[cfg(not(feature = "unstable_nts-pool"))] +pub async fn nts_pool_ke_main() -> Result<(), ()> { + println!("Not enabled, enable `unstable_nts-pool` feature at build time to build functional pool ke server"); + Ok(()) +}