diff --git a/bindings/rust/s2n-tls-tokio/src/lib.rs b/bindings/rust/s2n-tls-tokio/src/lib.rs index 52f31daaa69..3e6d60e687b 100644 --- a/bindings/rust/s2n-tls-tokio/src/lib.rs +++ b/bindings/rust/s2n-tls-tokio/src/lib.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 use errno::{set_errno, Errno}; -use pin_project_lite::pin_project; use s2n_tls::{ config::Config, connection::{Builder, Connection}, @@ -12,7 +11,7 @@ use s2n_tls::{ use std::{ fmt, future::Future, - io, mem, + io, os::raw::{c_int, c_void}, pin::Pin, task::{ @@ -139,18 +138,6 @@ where } } -pin_project! { -struct BlindingState { - #[pin] - timer: Sleep, - - // The remembered error if we got into blinding because of - // an error, or Ok(()) if we didn't. After returning the error, - // this goes back to Ok(()). - remembered_error: Result<(), Error>, -} -} - pub struct TlsStream where C: AsRef + AsMut + Unpin, @@ -158,7 +145,8 @@ where { conn: C, stream: S, - blinding: Option>>, + blinding: Option>>, + shutdown_error: Option, } impl TlsStream @@ -182,6 +170,7 @@ where conn, stream, blinding: None, + shutdown_error: None, }; TlsHandshake { tls: &mut tls, @@ -255,35 +244,6 @@ where }) } - // Sets the blinding timer to the remaining blinding delay and possibly - // remembers an error. - // - // Returns the error if there was no blinding needed and the error - // did not need to be remembered. - fn set_blinding_timer( - self: Pin<&mut Self>, - mut remembered_error: Result<(), Error>, - ) -> Result<(), Error> { - let tls = self.get_mut(); - - if tls.blinding.is_none() { - let delay = tls.as_ref().remaining_blinding_delay()?; - if !delay.is_zero() { - // Sleep operates at the milisecond resolution, so add an extra - // millisecond to account for any stray nanoseconds. - let safety = Duration::from_millis(1); - // Return the error *later*, after the blinding is done - let remembered_error = mem::replace(&mut remembered_error, Ok(())); - tls.blinding = Some(Box::pin(BlindingState { - timer: sleep(delay.saturating_add(safety)), - remembered_error, - })); - } - } - - remembered_error - } - /// Polls the blinding timer, if there is any. /// /// s2n has a "blinding" functionality - when a bad behavior from the peer @@ -296,25 +256,24 @@ where /// before dropping an s2n connection, you should wait until either /// `poll_blinding` or `poll_shutdown` (which calls `poll_blinding` /// internally) returns ready. - pub fn poll_blinding( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll> { - self.as_mut().set_blinding_timer(Ok(()))?; - + pub fn poll_blinding(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { let tls = self.get_mut(); - if let Some(blinding) = &mut tls.blinding { - ready!(blinding.as_mut().project().timer.as_mut().poll(ctx)); - - // Set blinding to None to ensure the next go can have blinding - let mut blinding = tls.blinding.take().unwrap(); + if tls.blinding.is_none() { + let delay = tls.as_ref().remaining_blinding_delay()?; + if !delay.is_zero() { + // Sleep operates at the milisecond resolution, so add an extra + // millisecond to account for any stray nanoseconds. + let safety = Duration::from_millis(1); + tls.blinding = Some(Box::pin(sleep(delay.saturating_add(safety)))); + } + }; - // If there is an error, return it - mem::replace(blinding.as_mut().project().remembered_error, Ok(()))?; + if let Some(timer) = tls.blinding.as_mut() { + ready!(timer.as_mut().poll(ctx)); + tls.blinding = None; } - // Otherwise we are OK Poll::Ready(Ok(())) } @@ -404,19 +363,33 @@ where fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_blinding(ctx))?; - let status = ready!(self.as_mut().with_io(ctx, |mut context| { - context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ())) - })); + // s2n_shutdown must not be called again if it errors + if self.shutdown_error.is_none() { + let result = ready!(self.as_mut().with_io(ctx, |mut context| { + context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ())) + })); + if let Err(error) = result { + self.shutdown_error = Some(error); + // s2n_shutdown reading might have triggered blinding again + ready!(self.as_mut().poll_blinding(ctx))?; + } + }; + + let tcp_result = ready!(Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx)); - if let Err(e) = status { - // In case of an error shutting down, make sure you wait for - // the blinding timeout. - self.as_mut().set_blinding_timer(Err(e))?; - ready!(self.as_mut().poll_blinding(ctx))?; - unreachable!("should have returned the error we just put in!"); - } + if let Some(err) = self.shutdown_error.take() { + // poll methods shouldn't be called again after returning Ready, but + // nothing actually prevents it so poll_shutdown should handle it. + // s2n_shutdown can be polled indefinitely after succeeding, but not after failing. + // s2n_tls::error::Error isn't cloneable, so we can't just return the same error + // if poll_shutdown is called again. Instead, save a different error. + let next_error = Error::application("Shutdown called again after error".into()); + self.shutdown_error = Some(next_error); - Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx) + Ready(Err(io::Error::from(err))) + } else { + Ready(tcp_result) + } } } diff --git a/bindings/rust/s2n-tls-tokio/tests/common/mod.rs b/bindings/rust/s2n-tls-tokio/tests/common/mod.rs index b9fbf7832d8..fbcab0e1e06 100644 --- a/bindings/rust/s2n-tls-tokio/tests/common/mod.rs +++ b/bindings/rust/s2n-tls-tokio/tests/common/mod.rs @@ -40,6 +40,8 @@ pub static RSA_KEY_PEM: &[u8] = include_bytes!(concat!( pub const MIN_BLINDING_SECS: Duration = Duration::from_secs(10); pub const MAX_BLINDING_SECS: Duration = Duration::from_secs(30); +pub static TEST_STR: &str = "hello world"; + pub async fn get_streams() -> Result<(TcpStream, TcpStream), tokio::io::Error> { let localhost = "127.0.0.1".to_owned(); let listener = TcpListener::bind(format!("{}:0", localhost)).await?; diff --git a/bindings/rust/s2n-tls-tokio/tests/common/stream.rs b/bindings/rust/s2n-tls-tokio/tests/common/stream.rs index 6ef86cfce0c..889f2926282 100644 --- a/bindings/rust/s2n-tls-tokio/tests/common/stream.rs +++ b/bindings/rust/s2n-tls-tokio/tests/common/stream.rs @@ -14,11 +14,13 @@ use tokio::{ type ReadFn = Box, &mut Context, &mut ReadBuf) -> Poll>>; type WriteFn = Box, &mut Context, &[u8]) -> Poll>>; +type ShutdownFn = Box, &mut Context) -> Poll>>; #[derive(Default)] struct OverrideMethods { next_read: Option, next_write: Option, + next_shutdown: Option, } #[derive(Default)] @@ -36,6 +38,22 @@ impl Overrides { overrides.next_write = input; } } + + pub fn next_shutdown(&self, input: Option) { + if let Ok(mut overrides) = self.0.lock() { + overrides.next_shutdown = input; + } + } + + pub fn is_consumed(&self) -> bool { + if let Ok(overrides) = self.0.lock() { + overrides.next_read.is_none() + && overrides.next_write.is_none() + && overrides.next_shutdown.is_none() + } else { + false + } + } } unsafe impl Send for Overrides {} @@ -100,7 +118,17 @@ impl AsyncWrite for TestStream { Pin::new(&mut self.stream).poll_flush(ctx) } - fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.stream).poll_shutdown(ctx) + fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + let s = self.get_mut(); + let stream = Pin::new(&mut s.stream); + let action = match s.overrides.0.lock() { + Ok(mut overrides) => overrides.next_shutdown.take(), + _ => None, + }; + if let Some(f) = action { + (f)(stream, ctx) + } else { + stream.poll_shutdown(ctx) + } } } diff --git a/bindings/rust/s2n-tls-tokio/tests/shutdown.rs b/bindings/rust/s2n-tls-tokio/tests/shutdown.rs index 29931bd08bf..e67b87599dd 100644 --- a/bindings/rust/s2n-tls-tokio/tests/shutdown.rs +++ b/bindings/rust/s2n-tls-tokio/tests/shutdown.rs @@ -3,7 +3,12 @@ use s2n_tls::error; use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream}; -use std::{convert::TryFrom, sync::Arc}; +use std::{ + convert::TryFrom, + io, + sync::Arc, + task::Poll::{Pending, Ready}, +}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, time, @@ -164,7 +169,11 @@ async fn shutdown_with_blinding() -> Result<(), Box> { // Attempt to shutdown the client. This will eventually fail because the // server has not written the close_notify message yet, but it will at least // write the close_notify message that the server needs. - // Because time is mocked for testing, this does not actually take LONG_TIMEOUT. + // + // Because this test begins paused and relies on auto-advancing, this does + // not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs: + // https://docs.rs/tokio/latest/tokio/time/fn.pause.html + // // TODO: replace this with a half-close once the bindings support half-close. let timeout = time::timeout(LONG_TIMEOUT, client.shutdown()).await; assert!(timeout.is_err()); @@ -285,3 +294,113 @@ async fn shutdown_with_poll_blinding() -> Result<(), Box> Ok(()) } + +#[tokio::test(start_paused = true)] +async fn shutdown_with_tcp_error() -> Result<(), Box> { + let client = TlsConnector::new(common::client_config()?.build()?); + let server = TlsAcceptor::new(common::server_config()?.build()?); + + let (server_stream, client_stream) = common::get_streams().await?; + let server_stream = common::TestStream::new(server_stream); + let overrides = server_stream.overrides(); + + let (mut client, mut server) = + common::run_negotiate(&client, client_stream, &server, server_stream).await?; + + // Attempt to shutdown the client. This will eventually fail because the + // server has not written the close_notify message yet, but it will at least + // write the close_notify message that the server needs. + // + // Because this test begins paused and relies on auto-advancing, this does + // not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs: + // https://docs.rs/tokio/latest/tokio/time/fn.pause.html + // + // TODO: replace this with a half-close once the bindings support half-close. + _ = time::timeout(time::Duration::from_secs(600), client.shutdown()).await; + + // The underlying stream should return a unique error on shutdown + overrides.next_shutdown(Some(Box::new(|_, _| { + Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR))) + }))); + + // Shutdown should complete with the correct error from the underlying stream + let result = server.shutdown().await; + let error = result.unwrap_err().into_inner().unwrap(); + assert!(error.to_string() == common::TEST_STR); + + Ok(()) +} + +#[tokio::test] +async fn shutdown_with_tls_error_and_tcp_error() -> Result<(), Box> { + let client = TlsConnector::new(common::client_config()?.build()?); + let server = TlsAcceptor::new(common::server_config()?.build()?); + + let (server_stream, client_stream) = common::get_streams().await?; + let server_stream = common::TestStream::new(server_stream); + let overrides = server_stream.overrides(); + + let (_, mut server) = + common::run_negotiate(&client, client_stream, &server, server_stream).await?; + + // Both s2n_shutdown and the underlying stream should error on shutdown + overrides.next_read(Some(Box::new(|_, _, _| { + Ready(Err(io::Error::from(io::ErrorKind::Other))) + }))); + overrides.next_shutdown(Some(Box::new(|_, _| { + Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR))) + }))); + + // Shutdown should complete with the correct error from s2n_shutdown + let result = server.shutdown().await; + let io_error = result.unwrap_err(); + let error: error::Error = io_error.try_into()?; + // Any non-blocking read error is translated as "IOError" + assert!(error.kind() == error::ErrorType::IOError); + + // Even if s2n_shutdown fails, we need to close the underlying stream. + // Make sure we called our mock shutdown, consuming it. + assert!(overrides.is_consumed()); + + Ok(()) +} + +#[tokio::test] +async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box> { + let client = TlsConnector::new(common::client_config()?.build()?); + let server = TlsAcceptor::new(common::server_config()?.build()?); + + let (server_stream, client_stream) = common::get_streams().await?; + let server_stream = common::TestStream::new(server_stream); + let overrides = server_stream.overrides(); + + let (_, mut server) = + common::run_negotiate(&client, client_stream, &server, server_stream).await?; + + // We want s2n_shutdown to fail on read in order to ensure that it is only + // called once on failure. + // If s2n_shutdown were called again, the second call would hang waiting + // for nonexistent input from the peer. + overrides.next_read(Some(Box::new(|_, _, _| { + Ready(Err(io::Error::from(io::ErrorKind::Other))) + }))); + + // The underlying stream should initially return Pending, delaying shutdown + overrides.next_shutdown(Some(Box::new(|_, ctx| { + ctx.waker().wake_by_ref(); + Pending + }))); + + // Shutdown should complete with the correct error from s2n_shutdown + let result = server.shutdown().await; + let io_error = result.unwrap_err(); + let error: error::Error = io_error.try_into()?; + // Any non-blocking read error is translated as "IOError" + assert!(error.kind() == error::ErrorType::IOError); + + // Even if s2n_shutdown fails, we need to close the underlying stream. + // Make sure we at least called our mock shutdown, consuming it. + assert!(overrides.is_consumed()); + + Ok(()) +}