diff --git a/src/lib.rs b/src/lib.rs index 1b3119c4..3bf695b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; pub use rustls; +use rustls::server::AcceptedAlert; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -195,6 +196,7 @@ impl TlsAcceptor { pub struct LazyConfigAcceptor { acceptor: rustls::server::Acceptor, io: Option, + alert: Option<(rustls::Error, AcceptedAlert)>, } impl LazyConfigAcceptor @@ -206,6 +208,7 @@ where Self { acceptor, io: Some(io), + alert: None, } } @@ -274,6 +277,16 @@ where } }; + if let Some((err, mut alert)) = this.alert.take() { + return match alert.write(&mut common::SyncWriteAdapter { io, cx }) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + this.alert = Some((err, alert)); + Poll::Pending + } + _ => Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))), + }; + } + let mut reader = common::SyncReadAdapter { io, cx }; match this.acceptor.read_tls(&mut reader) { Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), @@ -287,11 +300,9 @@ where let io = this.io.take().unwrap(); return Poll::Ready(Ok(StartHandshake { accepted, io })); } - Ok(None) => continue, - Err((err, mut alert)) => { - let mut writer = common::SyncWriteAdapter { io, cx }; - let _ = alert.write(&mut writer); // best effort - return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))); + Ok(None) => {} + Err((err, alert)) => { + this.alert = Some((err, alert)); } } }