diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index 2619ade44..69a33e5d4 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -44,7 +44,7 @@ jobs: # Direct prompt for automated review (no @claude mention needed) direct_prompt: | Please review this pull request and look for bugs and security issues. - Only report on bugs and potential vulnerabilities you find. Be concise. + Only report issues you find, otherwise give a thumbs up. Be concise! # Optional: Use sticky comments to make Claude reuse the same comment on subsequent pushes to the same PR # use_sticky_comment: true diff --git a/CLAUDE.md b/CLAUDE.md index c15345d79..f898211c6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -71,3 +71,11 @@ Key architectural rule: The CDN/relay must not know about application logic, med - Run `just check` to execute all tests and linting - Rust tests are integrated within source files + +## Contributing + +For first-time contributors looking for tasks to work on: + +- Check [TODO.md](./TODO.md) for security and performance issues that need to be addressed +- Many of these issues are well-scoped and great for getting familiar with the codebase +- Security-related tasks help improve the robustness of the MoQ protocol implementation diff --git a/TODO.md b/TODO.md new file mode 100644 index 000000000..d514bf4b9 --- /dev/null +++ b/TODO.md @@ -0,0 +1,72 @@ +# TODO - Security & Performance Issues + +This file contains security and performance issues that need to be addressed. These are great tasks for first-time contributors to the MoQ project. + +## Security Issues + +### 🔒 DoS Protection & Rate Limiting + +- [ ] **Enforce maximum size for paths** - Add configurable limits for path string lengths to prevent memory exhaustion attacks +- [ ] **Enforce maximum number of active announcements** - Add configurable limit per session/connection to prevent announcement flooding +- [ ] **Enforce maximum number of subscriptions** - Currently implicit via MAX_STREAMS, make it configurable and explicit +- [ ] **Enforce maximum size for each frame** - Add configurable frame size limits to prevent large frame DoS attacks +- [ ] **Enforce maximum count of frames per group** - Limit frames per group to prevent unbounded memory allocation +- [ ] **Enforce cumulative maximums per session/IP/user** - Add aggregate limits across all connections from the same source + +### 🛡️ Input Validation & Bounds Checking + +- [ ] **Fix AnnounceInit decode DoS vector (Rust)** - Add hard limit check before processing count in `rs/moq/src/message/announce.rs:108-113` +- [ ] **Fix missing DoS protection (TypeScript)** - Add count limits in `js/moq/src/wire/announce.ts:62-67` +- [ ] **Fix prefix suffix handling bug** - Correct logic in `js/moq/src/publisher.ts:92-94` for proper hierarchical path handling +- [ ] **Add timeout protection for session initialization** - Prevent indefinite hangs in `rs/moq/src/session/mod.rs:64-66` + +### 🔍 Protocol Security + +- [ ] **Validate message sequence numbers** - Ensure monotonic ordering and detect replay attacks +- [ ] **Add authentication to sensitive operations** - Require proper auth for publish/announce operations +- [ ] **Implement proper error boundaries** - Prevent cascading failures from malformed messages +- [ ] **Add message rate limiting per connection** - Prevent control message flooding + +## Performance Issues + +### ⚡ Memory Management + +- [ ] **Implement bounded collections** - Replace unbounded Vec/Array usage with size-limited collections +- [ ] **Add memory pool for frequent allocations** - Reduce GC pressure in TypeScript and allocator pressure in Rust +- [ ] **Optimize string handling** - Use string interning for frequently used path names +- [ ] **Add configurable buffer sizes** - Make frame/group buffers configurable based on use case + +### 📊 Metrics & Observability + +- [ ] **Add connection health metrics** - Track bandwidth, latency, error rates per connection +- [ ] **Implement graceful degradation** - Reduce quality/features under resource pressure +- [ ] **Add resource usage monitoring** - Track memory, CPU, network usage per session +- [ ] **Log security events** - Audit log for rate limit violations, auth failures, etc. + +## Implementation Guidelines + +When working on these issues: + +1. **Security First**: Always validate inputs and add appropriate bounds checking +2. **Configurable Limits**: Make all limits configurable via environment variables or config files +3. **Backwards Compatibility**: Ensure changes don't break existing protocol compatibility +4. **Test Coverage**: Add tests for both normal operation and edge cases/attack scenarios +5. **Documentation**: Update protocol documentation and API docs for any changes +6. **Performance Testing**: Benchmark changes to ensure they don't introduce performance regressions + +## Getting Started + +New contributors should: + +1. Read the main [CLAUDE.md](./CLAUDE.md) for project setup and development guidelines +2. Run `just setup` to install dependencies +3. Run `just check` to ensure tests pass before making changes +4. Pick a single TODO item to work on +5. Create a PR with tests and documentation for your changes + +## Questions? + +For questions about these issues or implementation guidance, please: +- Open a GitHub issue with the `question` label +- Reference the specific TODO item you're asking about +- Include your proposed approach for discussion \ No newline at end of file diff --git a/js/moq/src/connection.ts b/js/moq/src/connection.ts index 6bcb491f5..c69bb5458 100644 --- a/js/moq/src/connection.ts +++ b/js/moq/src/connection.ts @@ -108,6 +108,9 @@ export class Connection { const conn = new Connection(adjustedUrl, quic, stream); + // The connection is now ready to use + // Note: ANNOUNCE_INIT will be handled when announce streams are actually requested + const cleanup = () => { conn.close(); }; diff --git a/js/moq/src/publisher.ts b/js/moq/src/publisher.ts index 9cd3afb61..39cb0c434 100644 --- a/js/moq/src/publisher.ts +++ b/js/moq/src/publisher.ts @@ -85,6 +85,20 @@ export class Publisher { async runAnnounce(msg: Wire.AnnounceInterest, stream: Wire.Stream) { const consumer = this.#announced.consume(msg.prefix); + // Send ANNOUNCE_INIT as the first message with all currently active paths + const activePaths: string[] = []; + for (const [name] of this.#broadcasts) { + if (name.startsWith(msg.prefix)) { + // Return suffix relative to prefix + const suffix = msg.prefix ? name.slice(msg.prefix.length + 1) : name; + activePaths.push(suffix); + } + } + + const init = new Wire.AnnounceInit(activePaths); + await init.encode(stream.writer); + + // Then send updates as they occur for (;;) { const announcement = await consumer.next(); if (!announcement) break; diff --git a/js/moq/src/subscriber.ts b/js/moq/src/subscriber.ts index 991f708fa..4be143cdf 100644 --- a/js/moq/src/subscriber.ts +++ b/js/moq/src/subscriber.ts @@ -44,6 +44,18 @@ export class Subscriber { try { const stream = await Wire.Stream.open(this.#quic, msg); + // First, receive ANNOUNCE_INIT + const init = await Wire.AnnounceInit.decode(stream.reader); + + // Process initial announcements + for (const path of init.paths) { + const full = prefix.concat(path); + console.debug(`announced: broadcast=${full} active=true`); + producer.write({ name: full, active: true }); + active.add(full); + } + + // Then receive updates for (;;) { const announce = await Wire.Announce.decode_maybe(stream.reader); if (!announce) { diff --git a/js/moq/src/wire/announce.ts b/js/moq/src/wire/announce.ts index 401401bff..01bc7e023 100644 --- a/js/moq/src/wire/announce.ts +++ b/js/moq/src/wire/announce.ts @@ -43,3 +43,27 @@ export class AnnounceInterest { return new AnnounceInterest(prefix); } } + +export class AnnounceInit { + paths: string[]; + + constructor(paths: string[]) { + this.paths = paths; + } + + async encode(w: Writer) { + await w.u53(this.paths.length); + for (const path of this.paths) { + await w.string(path); + } + } + + static async decode(r: Reader): Promise { + const count = await r.u53(); + const paths: string[] = []; + for (let i = 0; i < count; i++) { + paths.push(await r.string()); + } + return new AnnounceInit(paths); + } +} diff --git a/js/moq/src/wire/session.ts b/js/moq/src/wire/session.ts index 5d3e4e163..36687240e 100644 --- a/js/moq/src/wire/session.ts +++ b/js/moq/src/wire/session.ts @@ -11,9 +11,10 @@ export const Version = { FORK_03: 0xff0bad03, FORK_04: 0xff0bad04, LITE_00: 0xff0dad00, + LITE_01: 0xff0dad01, } as const; -export const CURRENT_VERSION = Version.LITE_00; +export const CURRENT_VERSION = Version.LITE_01; export class Extensions { entries: Map; diff --git a/rs/Cargo.lock b/rs/Cargo.lock index 9610e4b4d..5b82b0552 100644 --- a/rs/Cargo.lock +++ b/rs/Cargo.lock @@ -2096,9 +2096,9 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha", "rand_core", @@ -2138,9 +2138,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.13" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" +checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ "bitflags", ] diff --git a/rs/hang-cli/src/client.rs b/rs/hang-cli/src/client.rs index d29348bc2..cb1f3c075 100644 --- a/rs/hang-cli/src/client.rs +++ b/rs/hang-cli/src/client.rs @@ -35,13 +35,11 @@ async fn connect( // Create an origin producer to publish to the broadcast. let mut publisher = moq_lite::OriginProducer::default(); + publisher.publish(&name, consumer.inner.clone()); // Establish the connection, not providing a subscriber. let session = moq_lite::Session::connect(session, publisher.consume_all(), None).await?; - // Publish the broadcast using the origin producer directly. - publisher.publish(&name, consumer.inner.clone()); - tokio::select! { // On ctrl-c, close the session and exit. _ = tokio::signal::ctrl_c() => { diff --git a/rs/hang-cli/src/server.rs b/rs/hang-cli/src/server.rs index 66399bc67..676859eb1 100644 --- a/rs/hang-cli/src/server.rs +++ b/rs/hang-cli/src/server.rs @@ -73,6 +73,7 @@ async fn run_session( // Create an origin producer to publish to the broadcast. let mut publisher = moq_lite::OriginProducer::default(); + publisher.publish(&name, consumer.inner.clone()); let session = moq_lite::Session::accept(session, publisher.consume_all(), None) .await @@ -80,8 +81,6 @@ async fn run_session( tracing::info!(?id, "accepted session"); - publisher.publish(&name, consumer.inner.clone()); - Err(session.closed().await.into()) } diff --git a/rs/hang-gst/src/sink/imp.rs b/rs/hang-gst/src/sink/imp.rs index 3b220c907..e41aa9cc6 100644 --- a/rs/hang-gst/src/sink/imp.rs +++ b/rs/hang-gst/src/sink/imp.rs @@ -171,14 +171,15 @@ impl HangSink { let session = client.connect(url.clone()).await.expect("failed to connect"); let mut publisher = moq_lite::OriginProducer::default(); - let _session = moq_lite::Session::connect(session, publisher.consume_all(), None) - .await - .expect("failed to connect"); let broadcast = hang::BroadcastProducer::new(); let name = settings.broadcast.as_ref().expect("broadcast is required"); publisher.publish(name, broadcast.consume().inner); + let _session = moq_lite::Session::connect(session, publisher.consume_all(), None) + .await + .expect("failed to connect"); + let media = hang::cmaf::Import::new(broadcast); let mut state = self.state.lock().unwrap(); diff --git a/rs/hang-gst/src/source/imp.rs b/rs/hang-gst/src/source/imp.rs index ad872687e..d18641720 100644 --- a/rs/hang-gst/src/source/imp.rs +++ b/rs/hang-gst/src/source/imp.rs @@ -136,6 +136,11 @@ impl ElementImpl for HangSrc { gst::error!(CAT, obj = self.obj(), "Failed to setup: {:?}", e); return Err(gst::StateChangeError); } + // Chain up first to let the bin handle the state change + let result = self.parent_change_state(transition); + result?; + // This is a live source - no preroll needed + return Ok(gst::StateChangeSuccess::NoPreroll); } gst::StateChange::PausedToReady => { @@ -146,7 +151,7 @@ impl ElementImpl for HangSrc { _ => (), } - // Chain up + // Chain up for other transitions self.parent_change_state(transition) } } @@ -175,18 +180,14 @@ impl HangSrc { let origin = moq_lite::OriginProducer::default(); let _session = moq_lite::Session::connect(session, None, origin.clone()).await?; - // TODO giant hack to avoid a race condition with how announcements are now populated. - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - // Wait for the broadcast to be announced (race condition workaround) - let broadcast = origin.consume(&name).expect("broadcast not found"); + let broadcast = origin + .consume(&name) + .ok_or_else(|| anyhow::anyhow!("Broadcast '{}' not found", name))?; let mut broadcast = hang::BroadcastConsumer::new(broadcast); // TODO handle catalog updates let catalog = broadcast.catalog.next().await?.context("no catalog found")?.clone(); - gst::info!(CAT, "catalog: {:?}", catalog); - for video in catalog.video { let mut track = broadcast.subscribe(&video.track); diff --git a/rs/justfile b/rs/justfile index 1b6786ca5..8427cb05e 100644 --- a/rs/justfile +++ b/rs/justfile @@ -111,6 +111,7 @@ pub-gst name url: # Run gstreamer and pipe the output to our plugin GST_PLUGIN_PATH="${PWD}/target/debug${GST_PLUGIN_PATH:+:$GST_PLUGIN_PATH}" \ + GST_DEBUG="hangsink:4" \ gst-launch-1.0 -v -e multifilesrc location="dev/{{name}}.fmp4" loop=true ! qtdemux name=demux \ demux.video_0 ! h264parse ! queue ! identity sync=true ! isofmp4mux name=mux chunk-duration=1 fragment-duration=1 ! \ hangsink url="{{url}}" tls-disable-verify=true broadcast="{{name}}" \ @@ -124,6 +125,7 @@ sub name url: # Run gstreamer and pipe the output to our plugin # This will render the video to the screen GST_PLUGIN_PATH="${PWD}/target/debug${GST_PLUGIN_PATH:+:$GST_PLUGIN_PATH}" \ + GST_DEBUG="hangsrc:4" \ gst-launch-1.0 -v -e hangsrc url="{{url}}" broadcast="{{name}}" tls-disable-verify=true ! decodebin ! videoconvert ! autovideosink # Publish a video using ffmpeg directly from hang to the localhost diff --git a/rs/moq-clock/src/main.rs b/rs/moq-clock/src/main.rs index b9ba57e0e..185e5e0cf 100644 --- a/rs/moq-clock/src/main.rs +++ b/rs/moq-clock/src/main.rs @@ -62,11 +62,10 @@ async fn main() -> anyhow::Result<()> { let clock = clock::Publisher::new(track); let mut publisher = moq_lite::OriginProducer::default(); - let session = moq_lite::Session::connect(session, publisher.consume_all(), None).await?; - - // Publish the broadcast - the broadcast name is empty because the URL contains the name publisher.publish(&config.broadcast, broadcast.consume()); + let session = moq_lite::Session::connect(session, publisher.consume_all(), None).await?; + tokio::select! { res = session.closed() => Err(res.into()), _ = clock.run() => Ok(()), diff --git a/rs/moq/src/error.rs b/rs/moq/src/error.rs index 6d46bafad..ac3e5ab93 100644 --- a/rs/moq/src/error.rs +++ b/rs/moq/src/error.rs @@ -54,9 +54,6 @@ pub enum Error { #[error("protocol violation")] ProtocolViolation, - - #[error("unauthorized")] - Unauthorized, } impl Error { @@ -76,7 +73,6 @@ impl Error { Self::NotFound => 13, Self::WrongSize => 14, Self::ProtocolViolation => 15, - Self::Unauthorized => 16, Self::App(app) => *app + 64, } } diff --git a/rs/moq/src/message/announce.rs b/rs/moq/src/message/announce.rs index 64c38306c..b9096e97f 100644 --- a/rs/moq/src/message/announce.rs +++ b/rs/moq/src/message/announce.rs @@ -50,19 +50,19 @@ impl Encode for Announce { /// Sent by the subscriber to request ANNOUNCE messages. #[derive(Clone, Debug)] -pub struct AnnounceRequest { +pub struct AnnouncePlease { // Request tracks with this prefix. pub prefix: Path, } -impl Decode for AnnounceRequest { +impl Decode for AnnouncePlease { fn decode(r: &mut R) -> Result { let prefix = Path::decode(r)?; Ok(Self { prefix }) } } -impl Encode for AnnounceRequest { +impl Encode for AnnouncePlease { fn encode(&self, w: &mut W) { self.prefix.encode(w) } @@ -92,3 +92,35 @@ impl Encode for AnnounceStatus { (*self as u8).encode(w) } } + +/// Sent after setup to communicate the initially announced paths. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct AnnounceInit { + /// List of currently active broadcasts, encoded as suffixes to be combined with the prefix. + pub suffixes: Vec, +} + +impl Decode for AnnounceInit { + fn decode(r: &mut R) -> Result { + let count = u64::decode(r)?; + + // Don't allocate more than 1024 elements upfront + let mut paths = Vec::with_capacity(count.min(1024) as usize); + + for _ in 0..count { + paths.push(Path::decode(r)?); + } + + Ok(Self { suffixes: paths }) + } +} + +impl Encode for AnnounceInit { + fn encode(&self, w: &mut W) { + (self.suffixes.len() as u64).encode(w); + for path in &self.suffixes { + path.encode(w); + } + } +} diff --git a/rs/moq/src/message/versions.rs b/rs/moq/src/message/versions.rs index b8e34a259..3f54930aa 100644 --- a/rs/moq/src/message/versions.rs +++ b/rs/moq/src/message/versions.rs @@ -38,8 +38,9 @@ impl Version { pub const FORK_04: Version = Version(0xff0bad04); pub const LITE_00: Version = Version(0xff0dad00); + pub const LITE_01: Version = Version(0xff0dad01); - pub const CURRENT: Version = Version::LITE_00; + pub const CURRENT: Version = Version::LITE_01; } /// A version number negotiated during the setup. @@ -48,7 +49,8 @@ pub struct Alpn(pub &'static str); impl Alpn { pub const LITE_00: Alpn = Alpn("moql-00"); - pub const CURRENT: Alpn = Alpn::LITE_00; + pub const LITE_01: Alpn = Alpn("moql-01"); + pub const CURRENT: Alpn = Alpn::LITE_01; } impl From for Version { diff --git a/rs/moq/src/model/origin.rs b/rs/moq/src/model/origin.rs index 096064562..e60f5c5b5 100644 --- a/rs/moq/src/model/origin.rs +++ b/rs/moq/src/model/origin.rs @@ -304,6 +304,7 @@ impl OriginConsumer { /// Get a specific broadcast by path. /// /// This is relative to the consumer's prefix. + /// Returns None if the path hasn't been announced yet. pub fn consume<'a>(&self, suffix: impl Into>) -> Option { let path = self.prefix.join(suffix.into()); diff --git a/rs/moq/src/session/mod.rs b/rs/moq/src/session/mod.rs index 6c1803c33..a232cc96a 100644 --- a/rs/moq/src/session/mod.rs +++ b/rs/moq/src/session/mod.rs @@ -10,41 +10,40 @@ use publisher::*; use reader::*; use stream::*; use subscriber::*; +use tokio::sync::oneshot; use writer::*; -/// A MoQ session, constructed with [Publisher] and [Subscriber] halves. +/// A MoQ session, constructed with [OriginProducer] and [OriginConsumer] halves. /// /// This simplifies the state machine and immediately rejects any subscriptions that don't match the origin prefix. /// You probably want to use [Session] unless you're writing a relay. -#[derive(Clone)] pub struct Session { pub webtransport: web_transport::Session, } impl Session { - fn new( + async fn new( mut session: web_transport::Session, stream: Stream, // We will publish any local broadcasts from this origin. publish: Option, // We will consume any remote broadcasts, inserting them into this origin. subscribe: Option, - ) -> Self { - let publisher = SessionPublisher::new(session.clone(), publish); - let subscriber = SessionSubscriber::new(session.clone()); + ) -> Result { + let publisher = Publisher::new(session.clone(), publish); + let subscriber = Subscriber::new(session.clone(), subscribe); let this = Self { webtransport: session.clone(), }; + let init = oneshot::channel(); + web_async::spawn(async move { let res = tokio::select! { res = Self::run_session(stream) => res, - res = Self::run_bi(session.clone(), publisher.clone()) => res, - res = Self::run_uni(session.clone(), subscriber.clone()) => res, - //res = publisher.run() => res, - // Ignore Ok (unused) or when subscribe is None. - Some(Err(res)) = async move { Some(subscriber.run(subscribe?).await) } => Err(res), + res = publisher.run() => res, + res = subscriber.run(init.0) => res, }; match res { @@ -63,23 +62,26 @@ impl Session { } }); - this + // Wait until receiving the initial announcements to prevent some race conditions. + // Otherwise, `consume()` might return not found if we don't wait long enough, so just wait. + // If the announce stream fails or is closed, this will return an error instead of hanging. + // TODO return a better error + init.1.await.map_err(|_| Error::Cancel)?; + + Ok(this) } /// Perform the MoQ handshake as a client. - pub async fn connect< - T: Into, - P: Into>, - C: Into>, - >( - session: T, - publish: P, - subscribe: C, + pub async fn connect( + session: impl Into, + publish: impl Into>, + subscribe: impl Into>, ) -> Result { let mut session = session.into(); let mut stream = Stream::open(&mut session, message::ControlType::Session).await?; Self::connect_setup(&mut stream).await?; - Ok(Self::new(session, stream, publish.into(), subscribe.into())) + let session = Self::new(session, stream, publish.into(), subscribe.into()).await?; + Ok(session) } async fn connect_setup(setup: &mut Stream) -> Result<(), Error> { @@ -115,7 +117,8 @@ impl Session { } Self::accept_setup(&mut stream).await?; - Ok(Self::new(session, stream, publish.into(), subscribe.into())) + let session = Self::new(session, stream, publish.into(), subscribe.into()).await?; + Ok(session) } async fn accept_setup(control: &mut Stream) -> Result<(), Error> { @@ -137,63 +140,12 @@ impl Session { Ok(()) } + // TODO do something useful with this async fn run_session(mut stream: Stream) -> Result<(), Error> { while let Some(_info) = stream.reader.decode_maybe::().await? {} Err(Error::Cancel) } - async fn run_uni(mut session: web_transport::Session, subscriber: SessionSubscriber) -> Result<(), Error> { - loop { - let stream = Reader::accept(&mut session).await?; - let subscriber = subscriber.clone(); - - web_async::spawn(async move { - Self::run_data(stream, subscriber).await.ok(); - }); - } - } - - async fn run_data(mut stream: Reader, mut subscriber: SessionSubscriber) -> Result<(), Error> { - let kind = stream.decode().await?; - - let res = match kind { - message::DataType::Group => subscriber.recv_group(&mut stream).await, - }; - - if let Err(err) = res { - stream.abort(&err); - } - - Ok(()) - } - - async fn run_bi(mut session: web_transport::Session, publisher: SessionPublisher) -> Result<(), Error> { - loop { - let stream = Stream::accept(&mut session).await?; - let publisher = publisher.clone(); - - web_async::spawn(async move { - Self::run_control(stream, publisher).await.ok(); - }); - } - } - - async fn run_control(mut stream: Stream, mut publisher: SessionPublisher) -> Result<(), Error> { - let kind = stream.reader.decode().await?; - - let res = match kind { - message::ControlType::Session => Err(Error::UnexpectedStream(kind)), - message::ControlType::Announce => publisher.recv_announce(&mut stream).await, - message::ControlType::Subscribe => publisher.recv_subscribe(&mut stream).await, - }; - - if let Err(err) = &res { - stream.writer.abort(err); - } - - res - } - /// Close the underlying WebTransport session. pub fn close(mut self, err: Error) { self.webtransport.close(err.to_code(), &err.to_string()); diff --git a/rs/moq/src/session/publisher.rs b/rs/moq/src/session/publisher.rs index 900934d79..a316fe352 100644 --- a/rs/moq/src/session/publisher.rs +++ b/rs/moq/src/session/publisher.rs @@ -1,57 +1,93 @@ +use futures::FutureExt; use web_async::FuturesExt; -use crate::{message, model::GroupConsumer, Error, OriginConsumer, OriginUpdate, Path, Track, TrackConsumer}; +use crate::{ + message, model::GroupConsumer, Error, OriginConsumer, OriginProducer, OriginUpdate, Path, Track, TrackConsumer, +}; use super::{Stream, Writer}; #[derive(Clone)] -pub(super) struct SessionPublisher { +pub(super) struct Publisher { session: web_transport::Session, - // If None, then error on every request. - origin: Option, + origin: OriginConsumer, } -impl SessionPublisher { +impl Publisher { pub fn new(session: web_transport::Session, origin: Option) -> Self { + // Create a dummy origin that is immediately closed. + let origin = origin.unwrap_or_else(|| OriginProducer::default().consume_all()); Self { session, origin } } - /* pub async fn run(self) -> Result<(), Error> { - let origin = match self.origin { - Some(origin) => origin, - None => return Ok(()), - }; + // TODO block on origin.closed() + self.run_bi().await + } - // TODO await origin.closed() + async fn run_bi(mut self) -> Result<(), Error> { + loop { + let stream = Stream::accept(&mut self.session).await?; + + let this = self.clone(); + web_async::spawn(async move { + this.run_control(stream).await.ok(); + }); + } } - */ - pub async fn recv_announce(&mut self, stream: &mut Stream) -> Result<(), Error> { - let interest = stream.reader.decode::().await?; + async fn run_control(self, mut stream: Stream) -> Result<(), Error> { + let kind = stream.reader.decode().await?; - // Just for logging the fully qualified prefix. - let prefix = match self.origin.as_ref() { - Some(origin) => origin.prefix().join(&interest.prefix), - None => Path::new("unauthorized").join(&interest.prefix), + let res = match kind { + message::ControlType::Session => Err(Error::UnexpectedStream(kind)), + message::ControlType::Announce => self.recv_announce(&mut stream).await, + message::ControlType::Subscribe => self.recv_subscribe(&mut stream).await, }; - tracing::debug!(%prefix, "announce started"); + if let Err(err) = &res { + stream.writer.abort(err); + } + + res + } + + pub async fn recv_announce(mut self, stream: &mut Stream) -> Result<(), Error> { + let interest = stream.reader.decode::().await?; + + // Just for logging the fully qualified prefix. + let prefix = self.origin.prefix().join(&interest.prefix); let res = self.run_announce(stream, &interest.prefix).await; match res { - Err(Error::Cancel) => tracing::debug!(%prefix, "announce cancelled"), - Err(err) => tracing::debug!(?err, %prefix, "announce error"), - _ => tracing::trace!(%prefix, "announce complete"), + Err(Error::Cancel) => tracing::debug!(%prefix, "announcing cancelled"), + Err(err) => tracing::debug!(?err, %prefix, "announcing error"), + _ => tracing::trace!(%prefix, "announcing complete"), } Ok(()) } async fn run_announce(&mut self, stream: &mut Stream, prefix: &Path) -> Result<(), Error> { - let origin = self.origin.as_ref().ok_or(Error::Unauthorized)?; + let mut announced = self.origin.consume_prefix(prefix); - let mut announced = origin.consume_prefix(prefix); + let mut init = Vec::new(); + + // Send ANNOUNCE_INIT as the first message with all currently active paths + // We use `now_or_never` so `announced` keeps track of what has been sent for us. + while let Some(Some(OriginUpdate { suffix, active })) = announced.next().now_or_never() { + if active.is_some() { + tracing::debug!(broadcast = %prefix.join(&suffix), "announce"); + init.push(suffix); + } else { + // A potential race. + tracing::debug!(broadcast = %prefix.join(&suffix), "unannounce"); + init.retain(|path| path != &suffix); + } + } + + let announce_init = message::AnnounceInit { suffixes: init }; + stream.writer.encode(&announce_init).await?; // Flush any synchronously announced paths loop { @@ -78,7 +114,7 @@ impl SessionPublisher { } } - pub async fn recv_subscribe(&mut self, stream: &mut Stream) -> Result<(), Error> { + pub async fn recv_subscribe(mut self, stream: &mut Stream) -> Result<(), Error> { let mut subscribe = stream.reader.decode::().await?; tracing::debug!(id = %subscribe.id, broadcast = %subscribe.broadcast, track = %subscribe.track, "subscribed started"); @@ -101,14 +137,12 @@ impl SessionPublisher { } async fn run_subscribe(&mut self, stream: &mut Stream, subscribe: &mut message::Subscribe) -> Result<(), Error> { - let origin = self.origin.as_ref().ok_or(Error::Unauthorized)?; - let track = Track { name: subscribe.track.clone(), priority: subscribe.priority, }; - let broadcast = origin.consume(&subscribe.broadcast).ok_or(Error::NotFound)?; + let broadcast = self.origin.consume(&subscribe.broadcast).ok_or(Error::NotFound)?; let track = broadcast.subscribe(&track); // TODO wait until track.info() to get the *real* priority @@ -250,7 +284,7 @@ impl SessionPublisher { // But even with a group per frame, it will take ~6 days to reach that point. // TODO The behavior when two tracks share the same priority is undefined. Should we round-robin? fn stream_priority(track_priority: u8, group_sequence: u64) -> i32 { - let sequence = (0xFFFFFF - group_sequence as u32) & 0xFFFFFF; + let sequence = 0xFFFFFF - (group_sequence as u32 & 0xFFFFFF); ((track_priority as i32) << 24) | sequence as i32 } } @@ -262,10 +296,7 @@ mod test { #[test] fn stream_priority() { let assert = |track_priority, group_sequence, expected| { - assert_eq!( - SessionPublisher::stream_priority(track_priority, group_sequence), - expected - ); + assert_eq!(Publisher::stream_priority(track_priority, group_sequence), expected); }; const U24: i32 = (1 << 24) - 1; diff --git a/rs/moq/src/session/subscriber.rs b/rs/moq/src/session/subscriber.rs index ea534e02b..e82719039 100644 --- a/rs/moq/src/session/subscriber.rs +++ b/rs/moq/src/session/subscriber.rs @@ -8,49 +8,94 @@ use crate::{ TrackProducer, }; +use tokio::sync::oneshot; use web_async::{spawn, Lock}; use super::{Reader, Stream}; #[derive(Clone)] -pub(super) struct SessionSubscriber { +pub(super) struct Subscriber { session: web_transport::Session, + origin: Option, broadcasts: Lock>, subscribes: Lock>, next_id: Arc, } -impl SessionSubscriber { - pub fn new(session: web_transport::Session) -> Self { +impl Subscriber { + pub fn new(session: web_transport::Session, origin: Option) -> Self { Self { session, + origin, broadcasts: Default::default(), subscribes: Default::default(), next_id: Default::default(), } } - pub async fn run(self, origin: OriginProducer) -> Result<(), Error> { - let closed = origin.clone(); - - // Wait until the producer is no longer needed or the stream is closed. + /// Send a signal when the subscriber is initialized. + pub async fn run(self, init: oneshot::Sender<()>) -> Result<(), Error> { tokio::select! { - biased; // avoid run_inner if we're already unused - // Nobody wants to consume from this origin anymore. - _ = closed.unused() => Err(Error::Cancel), - res = self.run_inner(origin) => res, + Err(err) = self.clone().run_announce(init) => Err(err), + res = self.run_uni() => res, + } + } + + async fn run_uni(mut self) -> Result<(), Error> { + loop { + let stream = Reader::accept(&mut self.session).await?; + let this = self.clone(); + + web_async::spawn(async move { + this.run_uni_stream(stream).await.ok(); + }); } } - async fn run_inner(mut self, mut origin: OriginProducer) -> Result<(), Error> { + async fn run_uni_stream(mut self, mut stream: Reader) -> Result<(), Error> { + let kind = stream.decode().await?; + + let res = match kind { + message::DataType::Group => self.recv_group(&mut stream).await, + }; + + if let Err(err) = res { + stream.abort(&err); + } + + Ok(()) + } + + async fn run_announce(mut self, init: oneshot::Sender<()>) -> Result<(), Error> { + // Don't do anything if there's no origin configured. + if self.origin.is_none() { + let _ = init.send(()); + return Ok(()); + } + let mut stream = Stream::open(&mut self.session, message::ControlType::Announce).await?; - let msg = message::AnnounceRequest { prefix: "".into() }; + let msg = message::AnnouncePlease { prefix: "".into() }; stream.writer.encode(&msg).await?; let mut producers = HashMap::new(); + let msg: message::AnnounceInit = stream.reader.decode().await?; + for path in msg.suffixes { + tracing::debug!(broadcast = %path, "received announce"); + + let producer = BroadcastProducer::new(); + let consumer = producer.consume(); + + self.origin.as_mut().unwrap().publish(&path, consumer); + producers.insert(path.clone(), producer.clone()); + + spawn(self.clone().run_broadcast(path, producer)); + } + + let _ = init.send(()); + while let Some(announce) = stream.reader.decode_maybe::().await? { match announce { message::Announce::Active { suffix: path } => { @@ -60,7 +105,7 @@ impl SessionSubscriber { let consumer = producer.consume(); // Run the broadcast in the background until all consumers are dropped. - origin.publish(&path, consumer); + self.origin.as_mut().unwrap().publish(&path, consumer); producers.insert(path.clone(), producer.clone()); spawn(self.clone().run_broadcast(path, producer)); @@ -75,7 +120,7 @@ impl SessionSubscriber { } } - // Close the writer. + // Close the stream when there's nothing more to announce. stream.writer.finish().await }