diff --git a/ntp-proto/src/nts_record.rs b/ntp-proto/src/nts_record.rs index 14f35e1be..d4006ba83 100644 --- a/ntp-proto/src/nts_record.rs +++ b/ntp-proto/src/nts_record.rs @@ -1529,6 +1529,7 @@ pub struct KeyExchangeServer { #[derive(Debug)] enum State { Active { decoder: KeyExchangeServerDecoder }, + PendingError { error: KeyExchangeError }, Done, } @@ -1564,7 +1565,7 @@ impl KeyExchangeServer { Ok(()) } - fn send_error_record(mut tls_connection: rustls::ServerConnection, error: &KeyExchangeError) { + fn send_error_record(tls_connection: &mut rustls::ServerConnection, error: &KeyExchangeError) { let error_records = [ NtsRecord::Error { errorcode: error.to_error_code(), @@ -1575,7 +1576,7 @@ impl KeyExchangeServer { NtsRecord::EndOfMessage, ]; - if let Err(io) = Self::send_records(&mut tls_connection, &error_records) { + if let Err(io) = Self::send_records(tls_connection, &error_records) { tracing::debug!(key_exchange_error = ?error, io_error = ?io, "sending error record failed"); } } @@ -1589,64 +1590,65 @@ impl KeyExchangeServer { } let mut buf = [0; 512]; - match self.tls_connection.reader().read(&mut buf) { - Ok(0) => { - // the connection was closed cleanly by the client - // see https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read - ControlFlow::Break(self.end_of_file()) - } - Ok(n) => { - match self.state { - State::Active { decoder } => match decoder.step_with_slice(&buf[..n]) { - ControlFlow::Continue(decoder) => { - // more bytes are needed - self.state = State::Active { decoder }; - - // recursively invoke the progress function. This is very unlikely! - // - // Normally, all records are written with a single write call, and - // received as one unit. Using many write calls does not really make - // sense for a client. - // - // So then, the other reason we could end up here is if the buffer is - // full. But 512 bytes is a lot of space for this interaction, and - // should be sufficient in most cases. - ControlFlow::Continue(self) - } - ControlFlow::Break(Ok(data)) => { - // all records have been decoded; send a response - // continues for a clean shutdown of the connection by the client - self.state = State::Done; - self.decoder_done(data) - } - ControlFlow::Break(Err(error)) => { - Self::send_error_record(self.tls_connection, &error); - ControlFlow::Break(Err(error)) + loop { + match self.tls_connection.reader().read(&mut buf) { + Ok(0) => { + // the connection was closed cleanly by the client + // see https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read + if self.wants_write() { + return ControlFlow::Continue(self); + } else { + return ControlFlow::Break(self.end_of_file()); + } + } + Ok(n) => { + match self.state { + State::Active { decoder } => match decoder.step_with_slice(&buf[..n]) { + ControlFlow::Continue(decoder) => { + // more bytes are needed + self.state = State::Active { decoder }; + } + ControlFlow::Break(Ok(data)) => { + // all records have been decoded; send a response + // continues for a clean shutdown of the connection by the client + self.state = State::Done; + return self.decoder_done(data); + } + ControlFlow::Break(Err(error)) => { + Self::send_error_record(&mut self.tls_connection, &error); + self.state = State::PendingError { error }; + return ControlFlow::Continue(self); + } + }, + State::PendingError { .. } | State::Done => { + // client is sending more bytes, but we don't expect any more + // these extra bytes are ignored + return ControlFlow::Continue(self); } - }, - State::Done => { - // client is sending more bytes, but we don't expect any more - // these extra bytes are ignored - ControlFlow::Continue(self) } } + Err(e) => match e.kind() { + std::io::ErrorKind::WouldBlock => { + // basically an await; give other tasks a chance + return ControlFlow::Continue(self); + } + std::io::ErrorKind::UnexpectedEof => { + // the connection was closed uncleanly by the client + // see https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read + if self.wants_write() { + return ControlFlow::Continue(self); + } else { + return ControlFlow::Break(self.end_of_file()); + } + } + _ => { + let error = KeyExchangeError::Io(e); + Self::send_error_record(&mut self.tls_connection, &error); + self.state = State::PendingError { error }; + return ControlFlow::Continue(self); + } + }, } - Err(e) => match e.kind() { - std::io::ErrorKind::WouldBlock => { - // basically an await; give other tasks a chance - ControlFlow::Continue(self) - } - std::io::ErrorKind::UnexpectedEof => { - // the connection was closed uncleanly by the client - // see https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read - ControlFlow::Break(self.end_of_file()) - } - _ => { - let error = KeyExchangeError::Io(e); - Self::send_error_record(self.tls_connection, &error); - ControlFlow::Break(Err(error)) - } - }, } } @@ -1656,6 +1658,10 @@ impl KeyExchangeServer { // there are no more client bytes, but decoding was not finished yet Err(KeyExchangeError::IncompleteResponse) } + State::PendingError { error } => { + // We can now return the error + Err(error) + } State::Done => { // we're all done Ok(self.tls_connection) @@ -1736,8 +1742,11 @@ impl KeyExchangeServer { } } Err(key_extract_error) => { - Self::send_error_record(self.tls_connection, &key_extract_error); - ControlFlow::Break(Err(key_extract_error)) + Self::send_error_record(&mut self.tls_connection, &key_extract_error); + self.state = State::PendingError { + error: key_extract_error, + }; + ControlFlow::Continue(self) } } } diff --git a/ntpd/src/daemon/keyexchange.rs b/ntpd/src/daemon/keyexchange.rs index ccae012a1..91eb1f07c 100644 --- a/ntpd/src/daemon/keyexchange.rs +++ b/ntpd/src/daemon/keyexchange.rs @@ -533,6 +533,12 @@ where let no_write = write_blocks || !this.server.wants_write(); let no_read = read_blocks || !this.server.wants_read(); if no_write && no_read { + // Do any final processing needed + this.server = match this.server.progress() { + ControlFlow::Continue(client) => client, + ControlFlow::Break(Err(e)) => return Poll::Ready(Err(e)), + ControlFlow::Break(Ok(_)) => return Poll::Ready(Ok(())), + }; outer.inner = Some(this); return Poll::Pending; } @@ -696,6 +702,76 @@ mod tests { assert_eq!(result.port, 123); } + #[tokio::test] + async fn key_exchange_weird_packet() { + let provider = KeySetProvider::new(1); + let keyset = provider.get(); + #[cfg(feature = "unstable_nts-pool")] + let pool_certs = ["testdata/certificates/nos-nl.pem"]; + + let (_sender, keyset) = tokio::sync::watch::channel(keyset); + let nts_ke_config = NtsKeConfig { + certificate_chain_path: PathBuf::from("test-keys/end.fullchain.pem"), + private_key_path: PathBuf::from("test-keys/end.key"), + #[cfg(feature = "unstable_nts-pool")] + authorized_pool_server_certificates: pool_certs.iter().map(PathBuf::from).collect(), + key_exchange_timeout_ms: 1000, + concurrent_connections: 512, + listen: "0.0.0.0:5436".parse().unwrap(), + ntp_port: None, + ntp_server: None, + }; + + let _join_handle = spawn(nts_ke_config, keyset); + + // give the server some time to make the port available + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + + let mut stream = client_tls_stream("localhost", 5436).await; + + stream.write_all(b"\x80\x01\x00\x02\x00\x00\x80\x04\x00\x02\x00\x0f\x00\x64\x03\xec\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00").await.unwrap(); + stream.flush().await.unwrap(); + + let mut buf = [0u8; 2048]; + let len = stream.read(&mut buf).await.unwrap(); + assert_eq!(len, 880); + } + + #[tokio::test] + async fn key_exchange_bad_request() { + let provider = KeySetProvider::new(1); + let keyset = provider.get(); + #[cfg(feature = "unstable_nts-pool")] + let pool_certs = ["testdata/certificates/nos-nl.pem"]; + + let (_sender, keyset) = tokio::sync::watch::channel(keyset); + let nts_ke_config = NtsKeConfig { + certificate_chain_path: PathBuf::from("test-keys/end.fullchain.pem"), + private_key_path: PathBuf::from("test-keys/end.key"), + #[cfg(feature = "unstable_nts-pool")] + authorized_pool_server_certificates: pool_certs.iter().map(PathBuf::from).collect(), + key_exchange_timeout_ms: 1000, + concurrent_connections: 512, + listen: "0.0.0.0:5436".parse().unwrap(), + ntp_port: None, + ntp_server: None, + }; + + let _join_handle = spawn(nts_ke_config, keyset); + + // give the server some time to make the port available + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + + let mut stream = client_tls_stream("localhost", 5436).await; + + stream.write_all(b"\x80\x01\x00\x02\x00\x01\x80\x04\x00\x02\x00\x0f\x00\x64\x03\xec\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00").await.unwrap(); + stream.flush().await.unwrap(); + + let mut buf = [0u8; 2048]; + let len = stream.read(&mut buf).await.unwrap(); + assert_eq!(len, 16); + } + #[cfg(not(target_os = "macos"))] #[tokio::test] async fn key_exchange_connection_limiter() {