Skip to content

Commit e17465c

Browse files
feat: add Limited body (#61)
* feat: add `Limited` body * fix: correct size_hint, remove const generic * chore: use boxed error, pin project Co-authored-by: Programatik <programatik29@gmail.com> Co-authored-by: Programatik <programatik29@gmail.com>
1 parent 730e9bd commit e17465c

File tree

2 files changed

+301
-0
lines changed

2 files changed

+301
-0
lines changed

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
1616
mod empty;
1717
mod full;
18+
mod limited;
1819
mod next;
1920
mod size_hint;
2021

2122
pub mod combinators;
2223

2324
pub use self::empty::Empty;
2425
pub use self::full::Full;
26+
pub use self::limited::{LengthLimitError, Limited};
2527
pub use self::next::{Data, Trailers};
2628
pub use self::size_hint::SizeHint;
2729

src/limited.rs

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
use crate::{Body, SizeHint};
2+
use bytes::Buf;
3+
use http::HeaderMap;
4+
use pin_project_lite::pin_project;
5+
use std::error::Error;
6+
use std::fmt;
7+
use std::pin::Pin;
8+
use std::task::{Context, Poll};
9+
10+
pin_project! {
11+
/// A length limited body.
12+
///
13+
/// This body will return an error if more than the configured number
14+
/// of bytes are returned on polling the wrapped body.
15+
#[derive(Clone, Copy, Debug)]
16+
pub struct Limited<B> {
17+
remaining: usize,
18+
#[pin]
19+
inner: B,
20+
}
21+
}
22+
23+
impl<B> Limited<B> {
24+
/// Create a new `Limited`.
25+
pub fn new(inner: B, limit: usize) -> Self {
26+
Self {
27+
remaining: limit,
28+
inner,
29+
}
30+
}
31+
}
32+
33+
impl<B> Body for Limited<B>
34+
where
35+
B: Body,
36+
B::Error: Into<Box<dyn Error + Send + Sync>>,
37+
{
38+
type Data = B::Data;
39+
type Error = Box<dyn Error + Send + Sync>;
40+
41+
fn poll_data(
42+
self: Pin<&mut Self>,
43+
cx: &mut Context<'_>,
44+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
45+
let this = self.project();
46+
let res = match this.inner.poll_data(cx) {
47+
Poll::Pending => return Poll::Pending,
48+
Poll::Ready(None) => None,
49+
Poll::Ready(Some(Ok(data))) => {
50+
if data.remaining() > *this.remaining {
51+
*this.remaining = 0;
52+
Some(Err(LengthLimitError.into()))
53+
} else {
54+
*this.remaining -= data.remaining();
55+
Some(Ok(data))
56+
}
57+
}
58+
Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
59+
};
60+
61+
Poll::Ready(res)
62+
}
63+
64+
fn poll_trailers(
65+
self: Pin<&mut Self>,
66+
cx: &mut Context<'_>,
67+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
68+
let this = self.project();
69+
let res = match this.inner.poll_trailers(cx) {
70+
Poll::Pending => return Poll::Pending,
71+
Poll::Ready(Ok(data)) => Ok(data),
72+
Poll::Ready(Err(err)) => Err(err.into()),
73+
};
74+
75+
Poll::Ready(res)
76+
}
77+
78+
fn is_end_stream(&self) -> bool {
79+
self.inner.is_end_stream()
80+
}
81+
82+
fn size_hint(&self) -> SizeHint {
83+
use std::convert::TryFrom;
84+
match u64::try_from(self.remaining) {
85+
Ok(n) => {
86+
let mut hint = self.inner.size_hint();
87+
if hint.lower() >= n {
88+
hint.set_exact(n)
89+
} else if let Some(max) = hint.upper() {
90+
hint.set_upper(n.min(max))
91+
} else {
92+
hint.set_upper(n)
93+
}
94+
hint
95+
}
96+
Err(_) => self.inner.size_hint(),
97+
}
98+
}
99+
}
100+
101+
/// An error returned when body length exceeds the configured limit.
102+
#[derive(Debug)]
103+
#[non_exhaustive]
104+
pub struct LengthLimitError;
105+
106+
impl fmt::Display for LengthLimitError {
107+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108+
f.write_str("length limit exceeded")
109+
}
110+
}
111+
112+
impl Error for LengthLimitError {}
113+
114+
#[cfg(test)]
115+
mod tests {
116+
use super::*;
117+
use crate::Full;
118+
use bytes::Bytes;
119+
use std::convert::Infallible;
120+
121+
#[tokio::test]
122+
async fn read_for_body_under_limit_returns_data() {
123+
const DATA: &[u8] = b"testing";
124+
let inner = Full::new(Bytes::from(DATA));
125+
let body = &mut Limited::new(inner, 8);
126+
127+
let mut hint = SizeHint::new();
128+
hint.set_upper(7);
129+
assert_eq!(body.size_hint().upper(), hint.upper());
130+
131+
let data = body.data().await.unwrap().unwrap();
132+
assert_eq!(data, DATA);
133+
hint.set_upper(0);
134+
assert_eq!(body.size_hint().upper(), hint.upper());
135+
136+
assert!(matches!(body.data().await, None));
137+
}
138+
139+
#[tokio::test]
140+
async fn read_for_body_over_limit_returns_error() {
141+
const DATA: &[u8] = b"testing a string that is too long";
142+
let inner = Full::new(Bytes::from(DATA));
143+
let body = &mut Limited::new(inner, 8);
144+
145+
let mut hint = SizeHint::new();
146+
hint.set_upper(8);
147+
assert_eq!(body.size_hint().upper(), hint.upper());
148+
149+
let error = body.data().await.unwrap().unwrap_err();
150+
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
151+
}
152+
153+
struct Chunky(&'static [&'static [u8]]);
154+
155+
impl Body for Chunky {
156+
type Data = &'static [u8];
157+
type Error = Infallible;
158+
159+
fn poll_data(
160+
self: Pin<&mut Self>,
161+
_cx: &mut Context<'_>,
162+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
163+
let mut this = self;
164+
match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
165+
Some((data, new_tail)) => {
166+
this.0 = new_tail;
167+
168+
Poll::Ready(Some(data))
169+
}
170+
None => Poll::Ready(None),
171+
}
172+
}
173+
174+
fn poll_trailers(
175+
self: Pin<&mut Self>,
176+
_cx: &mut Context<'_>,
177+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
178+
Poll::Ready(Ok(Some(HeaderMap::new())))
179+
}
180+
}
181+
182+
#[tokio::test]
183+
async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
184+
) {
185+
const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
186+
let inner = Chunky(DATA);
187+
let body = &mut Limited::new(inner, 8);
188+
189+
let mut hint = SizeHint::new();
190+
hint.set_upper(8);
191+
assert_eq!(body.size_hint().upper(), hint.upper());
192+
193+
let data = body.data().await.unwrap().unwrap();
194+
assert_eq!(data, DATA[0]);
195+
hint.set_upper(0);
196+
assert_eq!(body.size_hint().upper(), hint.upper());
197+
198+
let error = body.data().await.unwrap().unwrap_err();
199+
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
200+
}
201+
202+
#[tokio::test]
203+
async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
204+
const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
205+
let inner = Chunky(DATA);
206+
let body = &mut Limited::new(inner, 8);
207+
208+
let mut hint = SizeHint::new();
209+
hint.set_upper(8);
210+
assert_eq!(body.size_hint().upper(), hint.upper());
211+
212+
let error = body.data().await.unwrap().unwrap_err();
213+
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
214+
}
215+
216+
#[tokio::test]
217+
async fn read_for_chunked_body_under_limit_is_okay() {
218+
const DATA: &[&[u8]] = &[b"test", b"ing!"];
219+
let inner = Chunky(DATA);
220+
let body = &mut Limited::new(inner, 8);
221+
222+
let mut hint = SizeHint::new();
223+
hint.set_upper(8);
224+
assert_eq!(body.size_hint().upper(), hint.upper());
225+
226+
let data = body.data().await.unwrap().unwrap();
227+
assert_eq!(data, DATA[0]);
228+
hint.set_upper(4);
229+
assert_eq!(body.size_hint().upper(), hint.upper());
230+
231+
let data = body.data().await.unwrap().unwrap();
232+
assert_eq!(data, DATA[1]);
233+
hint.set_upper(0);
234+
assert_eq!(body.size_hint().upper(), hint.upper());
235+
236+
assert!(matches!(body.data().await, None));
237+
}
238+
239+
#[tokio::test]
240+
async fn read_for_trailers_propagates_inner_trailers() {
241+
const DATA: &[&[u8]] = &[b"test", b"ing!"];
242+
let inner = Chunky(DATA);
243+
let body = &mut Limited::new(inner, 8);
244+
let trailers = body.trailers().await.unwrap();
245+
assert_eq!(trailers, Some(HeaderMap::new()))
246+
}
247+
248+
#[derive(Debug)]
249+
enum ErrorBodyError {
250+
Data,
251+
Trailers,
252+
}
253+
254+
impl fmt::Display for ErrorBodyError {
255+
fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
256+
Ok(())
257+
}
258+
}
259+
260+
impl Error for ErrorBodyError {}
261+
262+
struct ErrorBody;
263+
264+
impl Body for ErrorBody {
265+
type Data = &'static [u8];
266+
type Error = ErrorBodyError;
267+
268+
fn poll_data(
269+
self: Pin<&mut Self>,
270+
_cx: &mut Context<'_>,
271+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
272+
Poll::Ready(Some(Err(ErrorBodyError::Data)))
273+
}
274+
275+
fn poll_trailers(
276+
self: Pin<&mut Self>,
277+
_cx: &mut Context<'_>,
278+
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
279+
Poll::Ready(Err(ErrorBodyError::Trailers))
280+
}
281+
}
282+
283+
#[tokio::test]
284+
async fn read_for_body_returning_error_propagates_error() {
285+
let body = &mut Limited::new(ErrorBody, 8);
286+
let error = body.data().await.unwrap().unwrap_err();
287+
assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data)));
288+
}
289+
290+
#[tokio::test]
291+
async fn trailers_for_body_returning_error_propagates_error() {
292+
let body = &mut Limited::new(ErrorBody, 8);
293+
let error = body.trailers().await.unwrap_err();
294+
assert!(matches!(
295+
error.downcast_ref(),
296+
Some(ErrorBodyError::Trailers)
297+
));
298+
}
299+
}

0 commit comments

Comments
 (0)