From dffae5fb497dbead73c3d2c87f4d5f9e6e75b2b3 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Tue, 2 Mar 2021 18:35:40 -0800 Subject: [PATCH] don't panic if the connection is closed at any point --- src/client/decode.rs | 7 ++++++- tests/client_decode.rs | 25 +++++++++++++++++++++++++ tests/test_utils.rs | 34 +++++++++++++++++++++++++++------- 3 files changed, 58 insertions(+), 8 deletions(-) diff --git a/src/client/decode.rs b/src/client/decode.rs index 9ef6317..7e98bc1 100644 --- a/src/client/decode.rs +++ b/src/client/decode.rs @@ -29,7 +29,12 @@ where loop { let bytes_read = reader.read_until(LF, &mut buf).await?; // No more bytes are yielded from the stream. - assert!(bytes_read != 0, "Empty response"); // TODO: ensure? + + match (bytes_read, buf.len()) { + (0, 0) => return Err(format_err!("connection closed")), + (0, _) => return Err(format_err!("empty response")), + _ => {} + } // Prevent CWE-400 DDOS with large HTTP Headers. ensure!( diff --git a/tests/client_decode.rs b/tests/client_decode.rs index 65cd1f5..e844fed 100644 --- a/tests/client_decode.rs +++ b/tests/client_decode.rs @@ -1,4 +1,9 @@ +mod test_utils; + mod client_decode { + use std::io::Write; + + use super::test_utils::CloseableCursor; use async_h1::client; use async_std::io::Cursor; use http_types::headers; @@ -42,6 +47,26 @@ mod client_decode { Ok(()) } + #[async_std::test] + async fn connection_closure() -> Result<()> { + let mut cursor = CloseableCursor::default(); + cursor.write_all(b"HTTP/1.1 200 OK\r\nhost: example.com")?; + cursor.close(); + assert_eq!( + client::decode(cursor).await.unwrap_err().to_string(), + "empty response" + ); + + let cursor = CloseableCursor::default(); + cursor.close(); + assert_eq!( + client::decode(cursor).await.unwrap_err().to_string(), + "connection closed" + ); + + Ok(()) + } + #[async_std::test] async fn response_newlines() -> Result<()> { let res = decode_lines(vec![ diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 3c4ef4b..7b78384 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -2,7 +2,7 @@ use async_h1::{ client::Encoder, server::{ConnectionStatus, Server}, }; -use async_std::io::{Read, Write}; +use async_std::io::{Read as AsyncRead, Write as AsyncWrite}; use http_types::{Request, Response, Result}; use std::{ fmt::{Debug, Display}, @@ -58,7 +58,7 @@ where } } -impl Read for TestServer +impl AsyncRead for TestServer where F: Fn(Request) -> Fut, Fut: Future>, @@ -72,7 +72,7 @@ where } } -impl Write for TestServer +impl AsyncWrite for TestServer where F: Fn(Request) -> Fut, Fut: Future>, @@ -187,7 +187,17 @@ impl Debug for CloseableCursor { } } -impl Read for &CloseableCursor { +impl AsyncRead for CloseableCursor { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_read(cx, buf) + } +} + +impl AsyncRead for &CloseableCursor { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -209,7 +219,7 @@ impl Read for &CloseableCursor { } } -impl Write for &CloseableCursor { +impl AsyncWrite for &CloseableCursor { fn poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -237,7 +247,7 @@ impl Write for &CloseableCursor { } } -impl Read for TestIO { +impl AsyncRead for TestIO { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -247,7 +257,7 @@ impl Read for TestIO { } } -impl Write for TestIO { +impl AsyncWrite for TestIO { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -264,3 +274,13 @@ impl Write for TestIO { Pin::new(&mut &*self.write).poll_close(cx) } } + +impl std::io::Write for CloseableCursor { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write().unwrap().data.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +}