Skip to content

Commit

Permalink
Add ability to read fixed size array without allocating
Browse files Browse the repository at this point in the history
also rename len to size
  • Loading branch information
rklaehn committed Apr 10, 2024
1 parent defb4bd commit 5498e30
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pub mod http_adapter {
Ok(res.freeze())
}

async fn len(&mut self) -> io::Result<u64> {
async fn size(&mut self) -> io::Result<u64> {
let io_err = |text: &str| io::Error::new(io::ErrorKind::Other, text);
let head_response = self
.head_request()
Expand Down
99 changes: 75 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
//! an allocation.
#![deny(missing_docs, rustdoc::broken_intra_doc_links)]

use bytes::{Bytes, BytesMut};
use bytes::{Buf, Bytes, BytesMut};
use std::future::Future;
use std::io::{self, Cursor};

Expand All @@ -64,16 +64,16 @@ pub trait AsyncSliceReader {

/// Get the size of the resource in bytes
#[must_use = "io futures must be polled to completion"]
fn len(&mut self) -> impl Future<Output = io::Result<u64>>;
fn size(&mut self) -> impl Future<Output = io::Result<u64>>;
}

impl<'b, T: AsyncSliceReader> AsyncSliceReader for &'b mut T {
async fn read_at(&mut self, offset: u64, len: usize) -> io::Result<Bytes> {
(**self).read_at(offset, len).await
}

async fn len(&mut self) -> io::Result<u64> {
(**self).len().await
async fn size(&mut self) -> io::Result<u64> {
(**self).size().await
}
}

Expand All @@ -82,8 +82,8 @@ impl<T: AsyncSliceReader> AsyncSliceReader for Box<T> {
(**self).read_at(offset, len).await
}

async fn len(&mut self) -> io::Result<u64> {
(**self).len().await
async fn size(&mut self) -> io::Result<u64> {
(**self).size().await
}
}

Expand Down Expand Up @@ -172,45 +172,94 @@ pub trait AsyncStreamReader {
/// Read at most `len` bytes. To read to the end, pass `usize::MAX`.
///
/// returns an empty buffer to indicate EOF.
fn read(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>>;
fn read_bytes(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>>;

/// Read a fixed size buffer.
///
/// If there are fewer than `L` bytes available, an `io::ErrorKind::UnexpectedEof` error is returned.
fn read<const L: usize>(&mut self) -> impl Future<Output = io::Result<[u8; L]>>;
}

impl<T: AsyncStreamReader> AsyncStreamReader for &mut T {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
(**self).read(len).await
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
(**self).read_bytes(len).await
}

/// Read a fixed size array by forwarding to the `read` implementation of the referenced reader.
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
// `L` is picked up from this function's return type, so the inner call needs no turbofish.
(**self).read().await
}
}

impl AsyncStreamReader for Bytes {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
let res = self.split_to(len.min(Bytes::len(self)));
Ok(res)
}

/// Read exactly `L` bytes from the front of this `Bytes` into a fixed size array.
///
/// Returns an `UnexpectedEof` error when fewer than `L` bytes remain; the length check
/// happens before anything is split off, so the buffer is left untouched in that case.
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
if Bytes::len(self) < L {
return Err(io::ErrorKind::UnexpectedEof.into());
}
let mut res = [0u8; L];
// Split off the first `L` bytes (advancing `self` past them) and copy them into the array.
self.split_to(L).copy_to_slice(&mut res);
Ok(res)
}
}

// Stream reader over an in-memory `BytesMut`: every read splits bytes off the front of the buffer.
impl AsyncStreamReader for BytesMut {
/// Read at most `len` bytes, returning fewer (possibly none) when the buffer holds less.
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
// Clamp to the current length so `split_to` is never asked for more than is available.
let res = self.split_to(len.min(BytesMut::len(self)));
Ok(res.freeze())
}

/// Read exactly `L` bytes into a fixed size array.
///
/// Returns an `UnexpectedEof` error when fewer than `L` bytes remain; the length check
/// happens before anything is split off, so the buffer is left untouched in that case.
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
if BytesMut::len(self) < L {
return Err(io::ErrorKind::UnexpectedEof.into());
}
let mut res = [0u8; L];
// Split off the first `L` bytes (advancing `self` past them) and copy them into the array.
self.split_to(L).copy_to_slice(&mut res);
Ok(res)
}
}

impl AsyncStreamReader for &[u8] {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
let len = len.min(self.len());
let res = Bytes::copy_from_slice(&self[..len]);
*self = &self[len..];
Ok(res)
}
}

impl AsyncStreamReader for BytesMut {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
let res = self.split_to(len.min(BytesMut::len(self)));
Ok(res.freeze())
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
if self.len() < L {
return Err(io::ErrorKind::UnexpectedEof.into());
}
let mut res = [0u8; L];
res.copy_from_slice(&self[..L]);
*self = &self[L..];
Ok(res)
}
}

impl<T: AsyncSliceReader> AsyncStreamReader for Cursor<T> {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
let offset = self.position();
let res = self.get_mut().read_at(offset, len).await?;
self.set_position(offset + res.len() as u64);
Ok(res)
}

/// Read exactly `L` bytes starting at the current cursor position.
///
/// The position is advanced only after a full read; a short read returns `UnexpectedEof`
/// and leaves the position where it was.
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
let offset = self.position();
let res = self.get_mut().read_at(offset, L).await?;
if res.len() < L {
return Err(io::ErrorKind::UnexpectedEof.into());
}
self.set_position(offset + res.len() as u64);
let mut buf = [0u8; L];
// NOTE(review): `copy_from_slice` panics when the lengths differ, so this relies on
// `read_at` never returning more than the requested `L` bytes — confirm against the
// `AsyncSliceReader` contract.
buf.copy_from_slice(&res);
Ok(buf)
}
}

/// A non seekable writer, e.g. a network socket.
Expand Down Expand Up @@ -304,10 +353,10 @@ where
}
}

async fn len(&mut self) -> io::Result<u64> {
async fn size(&mut self) -> io::Result<u64> {
match self {
Self::Left(l) => l.len().await,
Self::Right(r) => r.len().await,
Self::Left(l) => l.size().await,
Self::Right(r) => r.size().await,
}
}
}
Expand Down Expand Up @@ -452,7 +501,7 @@ mod tests {
let res = file.read_at(0, usize::MAX).await?;
assert_eq!(res, expected);

let res = file.len().await?;
let res = file.size().await?;
assert_eq!(res, 100);

// read 3 bytes at offset 0
Expand Down Expand Up @@ -687,7 +736,7 @@ mod tests {
current = offset.checked_add(len as u64).unwrap();
}
ReadOp::Len => {
let len = AsyncSliceReader::len(&mut file).await?;
let len = AsyncSliceReader::size(&mut file).await?;
assert_eq!(len, actual.len() as u64);
}
}
Expand Down Expand Up @@ -717,7 +766,7 @@ mod tests {
let url = reqwest::Url::parse(&url).unwrap();
let server = tokio::spawn(server);
let mut reader = HttpAdapter::new(url);
let len = reader.len().await.unwrap();
let len = reader.size().await.unwrap();
assert_eq!(len, 11);
println!("len: {:?}", reader);
let part = reader.read_at(0, 11).await.unwrap();
Expand Down Expand Up @@ -747,7 +796,9 @@ mod tests {

#[test]
fn bytes_read(data in proptest::collection::vec(any::<u8>(), 0..1024), ops in random_read_ops(1024, 1024, 2)) {
async_test(read_op_test(ops, Bytes::from(data.clone()), &data)).unwrap();
async_test(read_op_test(ops.clone(), Bytes::from(data.clone()), &data)).unwrap();
async_test(read_op_test(ops.clone(), BytesMut::from(data.as_slice()), &data)).unwrap();
async_test(read_op_test(ops, data.as_slice(), &data)).unwrap();
}

#[cfg(feature = "tokio-io")]
Expand Down
14 changes: 12 additions & 2 deletions src/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ impl AsyncSliceReader for bytes::Bytes {
Ok(get_limited_slice(self, offset, len))
}

async fn len(&mut self) -> io::Result<u64> {
async fn size(&mut self) -> io::Result<u64> {
Ok(Bytes::len(self) as u64)
}
}
Expand All @@ -17,11 +17,21 @@ impl AsyncSliceReader for bytes::BytesMut {
Ok(copy_limited_slice(self, offset, len))
}

async fn len(&mut self) -> io::Result<u64> {
async fn size(&mut self) -> io::Result<u64> {
Ok(BytesMut::len(self) as u64)
}
}

// Random access reader over a borrowed byte slice.
impl AsyncSliceReader for &[u8] {
/// Read up to `len` bytes starting at `offset`.
// NOTE(review): range handling is delegated to `copy_limited_slice`, which is not shown
// here — presumably it clamps the range to the slice and copies it into a `Bytes`; confirm.
async fn read_at(&mut self, offset: u64, len: usize) -> io::Result<Bytes> {
Ok(copy_limited_slice(self, offset, len))
}

/// Return the length of the slice in bytes.
async fn size(&mut self) -> io::Result<u64> {
Ok(self.len() as u64)
}
}

impl AsyncSliceWriter for bytes::BytesMut {
async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
write_extend(self, offset, &data)
Expand Down
18 changes: 11 additions & 7 deletions src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,12 @@ impl<W> TrackingStreamReader<W> {
}

impl<W: AsyncStreamReader> AsyncStreamReader for TrackingStreamReader<W> {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
AggregateSizeAndStats::new(self.inner.read(len), &mut self.stats.read).await
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
AggregateSizeAndStats::new(self.inner.read_bytes(len), &mut self.stats.read).await
}

/// Read a fixed size array from the inner reader, recording the call in the `read` stats.
// NOTE(review): this shares `stats.read` with `read_bytes`; how `AggregateSizeAndStats`
// derives a size from an array result is not visible here — confirm.
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
AggregateSizeAndStats::new(self.inner.read(), &mut self.stats.read).await
}
}

Expand Down Expand Up @@ -288,8 +292,8 @@ impl<R: AsyncSliceReader> AsyncSliceReader for TrackingSliceReader<R> {
AggregateSizeAndStats::new(self.inner.read_at(offset, len), &mut self.stats.read_at).await
}

async fn len(&mut self) -> io::Result<u64> {
AggregateStats::new(self.inner.len(), &mut self.stats.len).await
async fn size(&mut self) -> io::Result<u64> {
AggregateStats::new(self.inner.size(), &mut self.stats.len).await
}
}

Expand Down Expand Up @@ -507,8 +511,8 @@ mod tests {
#[tokio::test]
async fn tracking_stream_reader() {
let mut writer = TrackingStreamReader::new(Bytes::from(vec![0, 1, 2, 3]));
writer.read(2).await.unwrap();
writer.read(3).await.unwrap();
writer.read_bytes(2).await.unwrap();
writer.read_bytes(3).await.unwrap();
assert_eq!(writer.stats().read.size, 4); // not 5, because the last read was only 2 bytes
assert_eq!(writer.stats().read.stats.count, 2);
}
Expand Down Expand Up @@ -537,7 +541,7 @@ mod tests {
let mut reader = TrackingSliceReader::new(Bytes::from(vec![1u8, 2, 3]));
let _ = reader.read_at(0, 1).await.unwrap();
let _ = reader.read_at(10, 1).await.unwrap();
let _ = reader.len().await.unwrap();
let _ = reader.size().await.unwrap();
assert_eq!(reader.stats().read_at.size, 1);
assert_eq!(reader.stats().read_at.stats.count, 2);
assert_eq!(reader.stats().len.count, 1);
Expand Down
29 changes: 25 additions & 4 deletions src/tokio_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub mod file {
Asyncify::from(self.0.take().map(|t| (t.read_at(offset, len), &mut self.0))).await
}

async fn len(&mut self) -> io::Result<u64> {
async fn size(&mut self) -> io::Result<u64> {
Asyncify::from(self.0.take().map(|t| (t.len(), &mut self.0))).await
}
}
Expand Down Expand Up @@ -289,12 +289,33 @@ impl<T: tokio::io::AsyncWrite + Unpin> AsyncStreamWriter for TokioStreamWriter<T

/// Utility to convert a [tokio::io::AsyncRead] into an [AsyncStreamReader].
#[derive(Debug, Clone)]
pub struct TokioStreamReader<T>(T);
pub struct TokioStreamReader<T>(pub T);

impl<T> TokioStreamReader<T> {
/// Create a new `TokioStreamReader` wrapping the given inner reader.
pub fn new(inner: T) -> Self {
Self(inner)
}

/// Consume the wrapper and return the inner reader.
pub fn into_inner(self) -> T {
self.0
}
}

impl<T: tokio::io::AsyncRead + Unpin> AsyncStreamReader for TokioStreamReader<T> {
async fn read(&mut self, len: usize) -> io::Result<Bytes> {
async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
let mut buf = Vec::with_capacity(len.min(MAX_PREALLOC));
(&mut self.0).take(len as u64).read_to_end(&mut buf).await?;
(&mut self.0)
.take(len as u64)
.read_to_end(&mut buf)
.await?;
Ok(buf.into())
}

/// Read exactly `L` bytes from the inner reader into a fixed size array.
///
/// Any error from `read_exact` is propagated unchanged.
// NOTE(review): presumably `read_exact` reports `UnexpectedEof` when the stream ends before
// `L` bytes were read, matching the trait documentation — confirm against the tokio docs.
async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
let mut buf = [0; L];
self.0.read_exact(&mut buf).await?;
Ok(buf)
}
}

0 comments on commit 5498e30

Please sign in to comment.