diff --git a/rs/moq-lite/src/client.rs b/rs/moq-lite/src/client.rs index 7f5a4d3f1..3862c5492 100644 --- a/rs/moq-lite/src/client.rs +++ b/rs/moq-lite/src/client.rs @@ -90,7 +90,7 @@ impl Client { .iter() .find(|v| coding::Version::from(**v) == server.version) .copied() - .ok_or_else(|| Error::Version(client.versions.clone(), supported.clone().into()))?; + .ok_or(Error::Version)?; match version { Version::Lite(version) => { diff --git a/rs/moq-lite/src/coding/reader.rs b/rs/moq-lite/src/coding/reader.rs index 1531c49fc..de81edfd5 100644 --- a/rs/moq-lite/src/coding/reader.rs +++ b/rs/moq-lite/src/coding/reader.rs @@ -1,4 +1,4 @@ -use std::{cmp, fmt::Debug, io, sync::Arc}; +use std::{cmp, fmt::Debug, io}; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -34,18 +34,12 @@ impl Reader { } Err(DecodeError::Short) => { // Try to read more data - if self - .stream - .read_buf(&mut self.buffer) - .await - .map_err(|e| Error::Transport(Arc::new(e)))? - .is_none() - { + if !self.read_more().await? { // Stream closed while we still need more data - return Err(Error::Decode(DecodeError::Short)); + return Err(Error::Decode); } } - Err(e) => return Err(Error::Decode(e)), + Err(e) => return Err(e.into()), } } } @@ -55,11 +49,11 @@ impl Reader { where V: Clone, { - match self.closed().await { - Ok(()) => Ok(None), - Err(Error::Decode(DecodeError::ExpectedEnd)) => Ok(Some(self.decode().await?)), - Err(e) => Err(e), + if !self.has_more().await? { + return Ok(None); } + + Ok(Some(self.decode().await?)) } /// Decode the next message from the stream without consuming it. @@ -73,18 +67,12 @@ impl Reader { Ok(msg) => return Ok(msg), Err(DecodeError::Short) => { // Try to read more data - if self - .stream - .read_buf(&mut self.buffer) - .await - .map_err(|e| Error::Transport(Arc::new(e)))? - .is_none() - { + if !self.read_more().await? { // Stream closed while we still need more data - return Err(Error::Decode(DecodeError::Short)); + return Err(Error::Decode); } } - Err(e) => return Err(Error::Decode(e)), + Err(e) => return Err(e.into()), } } } @@ -97,10 +85,7 @@ impl Reader { return Ok(Some(data)); } - self.stream - .read_chunk(max) - .await - .map_err(|e| Error::Transport(Arc::new(e))) + self.stream.read_chunk(max).await.map_err(Error::from_transport) } /// Read exactly the given number of bytes from the stream. @@ -118,10 +103,11 @@ impl Reader { buf.put(data); while buf.has_remaining_mut() { - self.stream - .read_buf(&mut buf) - .await - .map_err(|e| Error::Transport(Arc::new(e)))?; + match self.stream.read_buf(&mut buf).await { + Ok(Some(_)) => {} + Ok(None) => return Err(Error::Decode), + Err(e) => return Err(Error::from_transport(e)), + } } Ok(buf.into_inner().freeze()) @@ -138,8 +124,8 @@ impl Reader { .stream .read_chunk(size) .await - .map_err(|e| Error::Transport(Arc::new(e)))? - .ok_or(Error::Decode(DecodeError::Short))?; + .map_err(Error::from_transport)? + .ok_or(Error::Decode)?; size -= chunk.len(); } @@ -148,18 +134,29 @@ impl Reader { /// Wait until the stream is closed, erroring if there are any additional bytes. pub async fn closed(&mut self) -> Result<(), Error> { - if self.buffer.is_empty() - && self - .stream - .read_buf(&mut self.buffer) - .await - .map_err(|e| Error::Transport(Arc::new(e)))? - .is_none() - { - return Ok(()); + if self.has_more().await? { + return Err(Error::Decode); + } + + Ok(()) + } + + /// Returns true if there is more data available in the buffer or stream. + async fn has_more(&mut self) -> Result { + if !self.buffer.is_empty() { + return Ok(true); } - Err(DecodeError::ExpectedEnd.into()) + self.read_more().await + } + + /// Try to read more data from the stream. Returns true if data was read, false if stream closed. + async fn read_more(&mut self) -> Result { + match self.stream.read_buf(&mut self.buffer).await { + Ok(Some(_)) => Ok(true), + Ok(None) => Ok(false), + Err(e) => Err(Error::from_transport(e)), + } } /// Abort the stream with the given error. diff --git a/rs/moq-lite/src/coding/stream.rs b/rs/moq-lite/src/coding/stream.rs index 44c6c8a77..f92b99e44 100644 --- a/rs/moq-lite/src/coding/stream.rs +++ b/rs/moq-lite/src/coding/stream.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::Error; use crate::coding::{Reader, Writer}; @@ -15,7 +13,7 @@ impl Stream { where V: Clone, { - let (send, recv) = session.open_bi().await.map_err(|err| Error::Transport(Arc::new(err)))?; + let (send, recv) = session.open_bi().await.map_err(Error::from_transport)?; let writer = Writer::new(send, version.clone()); let reader = Reader::new(recv, version); @@ -28,10 +26,7 @@ impl Stream { where V: Clone, { - let (send, recv) = session - .accept_bi() - .await - .map_err(|err| Error::Transport(Arc::new(err)))?; + let (send, recv) = session.accept_bi().await.map_err(Error::from_transport)?; let writer = Writer::new(send, version.clone()); let reader = Reader::new(recv, version); diff --git a/rs/moq-lite/src/coding/writer.rs b/rs/moq-lite/src/coding/writer.rs index f8519aa95..eb4d70b4c 100644 --- a/rs/moq-lite/src/coding/writer.rs +++ b/rs/moq-lite/src/coding/writer.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, sync::Arc}; +use std::fmt::Debug; use crate::{Error, coding::*}; @@ -33,7 +33,7 @@ impl Writer { .unwrap() .write_buf(&mut self.buffer) .await - .map_err(|e| Error::Transport(Arc::new(e)))?; + .map_err(Error::from_transport)?; } Ok(()) @@ -46,7 +46,7 @@ impl Writer { .unwrap() .write_buf(buf) .await - .map_err(|e| Error::Transport(Arc::new(e))) + .map_err(Error::from_transport) } /// Write the entire [Buf] to the stream. @@ -61,11 +61,7 @@ impl Writer { /// Mark the stream as finished. pub fn finish(&mut self) -> Result<(), Error> { - self.stream - .as_mut() - .unwrap() - .finish() - .map_err(|e| Error::Transport(Arc::new(e))) + self.stream.as_mut().unwrap().finish().map_err(Error::from_transport) } /// Abort the stream with the given error. @@ -80,7 +76,7 @@ impl Writer { .unwrap() .closed() .await - .map_err(|e| Error::Transport(Arc::new(e)))?; + .map_err(Error::from_transport)?; Ok(()) } diff --git a/rs/moq-lite/src/error.rs b/rs/moq-lite/src/error.rs index 2ce9f990e..fb0ec6f6f 100644 --- a/rs/moq-lite/src/error.rs +++ b/rs/moq-lite/src/error.rs @@ -1,30 +1,22 @@ -use std::sync::Arc; - use crate::coding; -use web_transport_trait::{MaybeSend, MaybeSync}; - -/// A trait that is Send+Sync except on WASM. -pub trait SendSyncError: std::error::Error + MaybeSend + MaybeSync {} - -impl SendSyncError for T where T: std::error::Error + MaybeSend + MaybeSync {} /// A list of possible errors that can occur during the session. #[derive(thiserror::Error, Debug, Clone)] #[non_exhaustive] pub enum Error { - #[error("transport error: {0}")] - Transport(Arc), + #[error("transport error")] + Transport, - #[error("decode error: {0}")] - Decode(#[from] coding::DecodeError), + #[error("decode error")] + Decode, // TODO move to a ConnectError - #[error("unsupported versions: client={0:?} server={1:?}")] - Version(coding::Versions, coding::Versions), + #[error("unsupported versions")] + Version, /// A required extension was not present - #[error("extension required: {0}")] - RequiredExtension(u64), + #[error("extension required")] + RequiredExtension, /// An unexpected stream type was received #[error("unexpected stream type")] @@ -32,7 +24,7 @@ pub enum Error { /// Some VarInt was too large and we were too lazy to handle it #[error("varint bounds exceeded")] - BoundsExceeded(#[from] coding::BoundsExceeded), + BoundsExceeded, /// A duplicate ID was used // The broadcast/track is a duplicate @@ -53,7 +45,7 @@ pub enum Error { // The application closes the stream with a code. #[error("app code={0}")] - App(u32), + App(u16), #[error("not found")] NotFound, @@ -91,15 +83,15 @@ impl Error { pub fn to_code(&self) -> u32 { match self { Self::Cancel => 0, - Self::RequiredExtension(_) => 1, + Self::RequiredExtension => 1, Self::Old => 2, Self::Timeout => 3, - Self::Transport(_) => 4, - Self::Decode(_) => 5, + Self::Transport => 4, + Self::Decode => 5, Self::Unauthorized => 6, - Self::Version(..) => 9, + Self::Version => 9, Self::UnexpectedStream => 10, - Self::BoundsExceeded(_) => 11, + Self::BoundsExceeded => 11, Self::Duplicate => 12, Self::NotFound => 13, Self::WrongSize => 14, @@ -110,8 +102,62 @@ impl Error { Self::TooManyParameters => 19, Self::InvalidRole => 20, Self::UnknownAlpn(_) => 21, - Self::App(app) => *app + 64, + Self::App(app) => *app as u32 + 64, + } + } + + /// Decode an error from a wire code. + pub fn from_code(code: u32) -> Self { + match code { + 0 => Self::Cancel, + 1 => Self::RequiredExtension, + 2 => Self::Old, + 3 => Self::Timeout, + 4 => Self::Transport, + 5 => Self::Decode, + 6 => Self::Unauthorized, + 9 => Self::Version, + 10 => Self::UnexpectedStream, + 11 => Self::BoundsExceeded, + 12 => Self::Duplicate, + 13 => Self::NotFound, + 14 => Self::WrongSize, + 15 => Self::ProtocolViolation, + 16 => Self::UnexpectedMessage, + 17 => Self::Unsupported, + 18 => Self::TooLarge, + 19 => Self::TooManyParameters, + 20 => Self::InvalidRole, + code if code >= 64 => match u16::try_from(code - 64) { + Ok(app) => Self::App(app), + Err(_) => Self::ProtocolViolation, + }, + _ => Self::ProtocolViolation, + } + } + + /// Convert a transport error into an [Error], decoding stream reset codes. + pub fn from_transport(err: impl web_transport_trait::Error) -> Self { + if let Some(code) = err.stream_error() { + return Self::from_code(code); } + + tracing::warn!(%err, "transport error"); + Self::Transport + } +} + +impl From for Error { + fn from(err: coding::DecodeError) -> Self { + tracing::warn!(%err, "decode error"); + Error::Decode + } +} + +impl From for Error { + fn from(err: coding::BoundsExceeded) -> Self { + tracing::warn!(%err, "bounds exceeded"); + Error::BoundsExceeded } } diff --git a/rs/moq-lite/src/ietf/control.rs b/rs/moq-lite/src/ietf/control.rs index c231d3f28..e95aee766 100644 --- a/rs/moq-lite/src/ietf/control.rs +++ b/rs/moq-lite/src/ietf/control.rs @@ -48,7 +48,7 @@ impl Control { tracing::trace!(id = T::ID, size = buf.len(), hex = %hex::encode(&buf), "encoded control message"); - self.tx.send(buf).map_err(|e| Error::Transport(Arc::new(e)))?; + self.tx.send(buf).map_err(|_| Error::Transport)?; Ok(()) } diff --git a/rs/moq-lite/src/ietf/publisher.rs b/rs/moq-lite/src/ietf/publisher.rs index 212c49dde..73acad71a 100644 --- a/rs/moq-lite/src/ietf/publisher.rs +++ b/rs/moq-lite/src/ietf/publisher.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use tokio::sync::oneshot; use web_async::{FuturesExt, Lock}; @@ -264,10 +264,7 @@ impl Publisher { version: Version, ) -> Result<(), Error> { // TODO add a way to open in priority order. - let mut stream = session - .open_uni() - .await - .map_err(|err| Error::Transport(Arc::new(err)))?; + let mut stream = session.open_uni().await.map_err(Error::from_transport)?; stream.set_priority(priority); let mut stream = Writer::new(stream, version); @@ -455,10 +452,7 @@ impl Publisher { // We literally just create a stream and FIN it. async fn run_fetch(session: S, request_id: RequestId, version: Version) -> Result<(), Error> { - let stream = session - .open_uni() - .await - .map_err(|err| Error::Transport(Arc::new(err)))?; + let stream = session.open_uni().await.map_err(Error::from_transport)?; let mut writer = Writer::new(stream, version); diff --git a/rs/moq-lite/src/ietf/session.rs b/rs/moq-lite/src/ietf/session.rs index 270c660f4..698db97f9 100644 --- a/rs/moq-lite/src/ietf/session.rs +++ b/rs/moq-lite/src/ietf/session.rs @@ -27,7 +27,7 @@ pub(crate) async fn start( ) .await { - Err(Error::Transport(_)) => { + Err(Error::Transport) => { tracing::info!("session terminated"); session.close(1, ""); } diff --git a/rs/moq-lite/src/ietf/subscriber.rs b/rs/moq-lite/src/ietf/subscriber.rs index f6ab97484..a8936e675 100644 --- a/rs/moq-lite/src/ietf/subscriber.rs +++ b/rs/moq-lite/src/ietf/subscriber.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{HashMap, hash_map::Entry}, - sync::Arc, -}; +use std::collections::{HashMap, hash_map::Entry}; use crate::{ Broadcast, Error, Frame, FrameProducer, Group, GroupProducer, OriginProducer, Path, PathOwned, Track, @@ -253,11 +250,7 @@ impl Subscriber { pub async fn run(self) -> Result<(), Error> { loop { - let stream = self - .session - .accept_uni() - .await - .map_err(|err| Error::Transport(Arc::new(err)))?; + let stream = self.session.accept_uni().await.map_err(Error::from_transport)?; let stream = Reader::new(stream, self.version); let this = self.clone(); @@ -385,7 +378,7 @@ impl Subscriber { }; match res { - Err(Error::Cancel) | Err(Error::Transport(_)) => { + Err(Error::Cancel) => { tracing::trace!(group = %producer.info.sequence, "group cancelled"); producer.abort(Error::Cancel); } diff --git a/rs/moq-lite/src/lite/publisher.rs b/rs/moq-lite/src/lite/publisher.rs index 4f0c90c65..893122989 100644 --- a/rs/moq-lite/src/lite/publisher.rs +++ b/rs/moq-lite/src/lite/publisher.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use futures::{FutureExt, StreamExt, stream::FuturesUnordered}; use web_async::FuturesExt; @@ -68,7 +66,7 @@ impl Publisher { Error::Cancel => { tracing::debug!(prefix = %origin.absolute(prefix), "announcing cancelled"); } - Error::Transport(_) => { + Error::Transport => { tracing::debug!(prefix = %origin.absolute(prefix), "announcing cancelled"); } err => { @@ -160,7 +158,7 @@ impl Publisher { { match &err { // TODO better classify WebTransport errors. - Error::Cancel | Error::Transport(_) => { + Error::Cancel | Error::Transport => { tracing::info!(%id, broadcast = %absolute, %track, "subscribed cancelled") } err => { @@ -250,10 +248,7 @@ impl Publisher { version: Version, ) -> Result<(), Error> { // TODO add a way to open in priority order. - let stream = session - .open_uni() - .await - .map_err(|err| Error::Transport(Arc::new(err)))?; + let stream = session.open_uni().await.map_err(Error::from_transport)?; let mut stream = Writer::new(stream, version); stream.set_priority(priority.current()); diff --git a/rs/moq-lite/src/lite/session.rs b/rs/moq-lite/src/lite/session.rs index 230b6925e..79226e398 100644 --- a/rs/moq-lite/src/lite/session.rs +++ b/rs/moq-lite/src/lite/session.rs @@ -32,7 +32,7 @@ pub(crate) async fn start( }; match res { - Err(Error::Transport(_)) => { + Err(Error::Transport) => { tracing::info!("session terminated"); session.close(1, ""); } diff --git a/rs/moq-lite/src/lite/subscriber.rs b/rs/moq-lite/src/lite/subscriber.rs index 41bbd3007..9ab5cf41b 100644 --- a/rs/moq-lite/src/lite/subscriber.rs +++ b/rs/moq-lite/src/lite/subscriber.rs @@ -45,11 +45,7 @@ impl Subscriber { async fn run_uni(self) -> Result<(), Error> { loop { - let stream = self - .session - .accept_uni() - .await - .map_err(|err| Error::Transport(Arc::new(err)))?; + let stream = self.session.accept_uni().await.map_err(Error::from_transport)?; let stream = Reader::new(stream, self.version); let this = self.clone(); @@ -191,7 +187,7 @@ impl Subscriber { }; match res { - Err(Error::Cancel) | Err(Error::Transport(_)) => { + Err(Error::Cancel) => { tracing::info!(id, broadcast = %self.log_path(&broadcast), track = %track.info.name, "subscribe cancelled"); track.abort(Error::Cancel); } @@ -252,7 +248,7 @@ impl Subscriber { }; match res { - Err(Error::Cancel) | Err(Error::Transport(_)) => { + Err(Error::Cancel) => { tracing::trace!(group = %group.info.sequence, "group cancelled"); group.abort(Error::Cancel); } diff --git a/rs/moq-lite/src/server.rs b/rs/moq-lite/src/server.rs index fde5ab204..6789d377d 100644 --- a/rs/moq-lite/src/server.rs +++ b/rs/moq-lite/src/server.rs @@ -72,7 +72,7 @@ impl Server { .iter() .flat_map(|v| Version::try_from(*v).ok()) .find(|v| supported.contains(v)) - .ok_or_else(|| Error::Version(client.versions.clone(), supported.into()))?; + .ok_or(Error::Version)?; // Only encode parameters if we're using the IETF draft because it has max_request_id let parameters = match version { diff --git a/rs/moq-lite/src/session.rs b/rs/moq-lite/src/session.rs index 5eec662b6..de661e77f 100644 --- a/rs/moq-lite/src/session.rs +++ b/rs/moq-lite/src/session.rs @@ -1,4 +1,4 @@ -use std::{future::Future, pin::Pin, sync::Arc}; +use std::{future::Future, pin::Pin}; use crate::Error; @@ -8,13 +8,13 @@ use crate::Error; /// - [`crate::Client::connect`] for clients. /// - [`crate::Server::accept`] for servers. pub struct Session { - session: Arc, + session: Box, } impl Session { pub(super) fn new(session: S) -> Self { Self { - session: Arc::new(session), + session: Box::new(session), } } @@ -26,8 +26,8 @@ impl Session { /// Block until the transport session is closed. // TODO Remove the Result the next time we make a breaking change. pub async fn closed(&self) -> Result<(), Error> { - let err = self.session.closed().await; - Err(Error::Transport(err)) + self.session.closed().await; + Err(Error::Transport) } } @@ -40,7 +40,7 @@ impl Drop for Session { // We use a wrapper type that is dyn-compatible to remove the generic bounds from Session. trait SessionInner: Send + Sync { fn close(&self, code: u32, reason: &str); - fn closed(&self) -> Pin> + Send + '_>>; + fn closed(&self) -> Pin + Send + '_>>; } impl SessionInner for S { @@ -48,7 +48,9 @@ impl SessionInner for S { S::close(self, code, reason); } - fn closed(&self) -> Pin> + Send + '_>> { - Box::pin(async move { Arc::new(S::closed(self).await) as Arc }) + fn closed(&self) -> Pin + Send + '_>> { + Box::pin(async move { + let _ = S::closed(self).await; + }) } }