From f8fd742eaded71994d543bafe82b7ed5e50c257a Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Tue, 5 Mar 2024 13:14:52 -0800 Subject: [PATCH] fix: Check for ErrorKind::WouldBlock in LazyConfigAcceptor --- src/lib.rs | 27 ++++++++++++++++++++++----- tests/test.rs | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1b3119c4..f4875c8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; pub use rustls; +use rustls::server::AcceptedAlert; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -195,6 +196,7 @@ impl TlsAcceptor { pub struct LazyConfigAcceptor { acceptor: rustls::server::Acceptor, io: Option, + alert: Option<(rustls::Error, AcceptedAlert)>, } impl LazyConfigAcceptor @@ -206,6 +208,7 @@ where Self { acceptor, io: Some(io), + alert: None, } } @@ -274,6 +277,22 @@ where } }; + if let Some((err, mut alert)) = this.alert.take() { + match alert.write(&mut common::SyncWriteAdapter { io, cx }) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + this.alert = Some((err, alert)); + return Poll::Pending; + } + Ok(0) | Err(_) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))) + } + Ok(_) => { + this.alert = Some((err, alert)); + continue; + } + }; + } + let mut reader = common::SyncReadAdapter { io, cx }; match this.acceptor.read_tls(&mut reader) { Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), @@ -287,11 +306,9 @@ where let io = this.io.take().unwrap(); return Poll::Ready(Ok(StartHandshake { accepted, io })); } - Ok(None) => continue, - Err((err, mut alert)) => { - let mut writer = common::SyncWriteAdapter { io, cx }; - let _ = alert.write(&mut writer); // best effort - return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))); + Ok(None) => {} + Err((err, alert)) => { + this.alert = Some((err, alert)); } } } diff --git a/tests/test.rs b/tests/test.rs index 8d3921b7..d70d1cae 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -241,5 +241,37 @@ async fn lazy_config_acceptor_take_io() -> Result<(), rustls::Error> { Ok(()) } +#[tokio::test] +async fn lazy_config_acceptor_alert() { + // Intentionally small so that we have to call alert.write several times + let (mut cstream, sstream) = tokio::io::duplex(2); + + let (tx, rx) = oneshot::channel(); + + tokio::spawn(async move { + // This is write instead of write_all because of the short duplex size, which is necessarily + // symmetrical. We never finish writing because the LazyConfigAcceptor returns an error + let _ = cstream.write(b"not tls").await; + let mut buf = Vec::new(); + cstream.read_to_end(&mut buf).await.unwrap(); + tx.send(buf).unwrap(); + }); + + let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream); + + let Ok(accept_result) = time::timeout(Duration::from_secs(3), acceptor).await else { + panic!("timeout"); + }; + + assert!(accept_result.is_err()); + + let Ok(Ok(received)) = time::timeout(Duration::from_secs(3), rx).await else { + panic!("failed to receive"); + }; + + let fatal_alert_decode_error = b"\x15\x03\x03\x00\x02\x02\x32"; + assert_eq!(received, fatal_alert_decode_error) +} + // Include `utils` module include!("utils.rs");