diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 943f7fa9..8009e359 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -8,10 +8,11 @@ on: env: RUST_BACKTRACE: 1 toolchain_style: stable - toolchain_msrv: 1.57 - toolchain_doc: nightly-2022-12-01 + toolchain_msrv: 1.63 + toolchain_h3_quinn_msrv: 1.66 + toolchain_doc: nightly-2023-10-21 toolchain_lint: stable - toolchain_fuzz: nightly-2022-12-01 + toolchain_fuzz: nightly-2023-10-21 jobs: ci-pass: @@ -21,6 +22,7 @@ jobs: - style - lint - msrv + - msrv_h3_quinn - test - doc - fuzz @@ -77,11 +79,30 @@ jobs: toolchain: ${{ env.toolchain_msrv }} override: true - uses: Swatinem/rust-cache@v2 - - name: cargo check --workspace + - name: cargo check -p h3 uses: actions-rs/cargo@v1 with: command: check - args: --workspace + args: -p h3 + + msrv_h3_quinn: + name: Check MSRV of `h3-quinn` + needs: [style] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Rust ${{ env.toolchain_h3_quinn_msrv }} + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ env.toolchain_h3_quinn_msrv }} + override: true + - uses: Swatinem/rust-cache@v2 + - name: cargo check -p h3-quinn + uses: actions-rs/cargo@v1 + with: + command: check + args: -p h3-quinn test: name: Test ${{ matrix.toolchain }} ${{ matrix.os }} @@ -118,7 +139,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: ${{ env.toolchain_doc }} + toolchain: ${{ env.toolchain_doc }} override: true - uses: Swatinem/rust-cache@v2 - name: cargo rustdoc -p h3 -- -D intra-doc-link-resolution-failure diff --git a/.gitignore b/.gitignore index c1e5bc64..973db0e1 100755 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,9 @@ target/ # macOS .DS_Store ._* + +# Example Certificates +localhost-key.pem +localhost.crt +localhost.key +localhost.pem diff --git a/Cargo.toml b/Cargo.toml index 89a45c59..355c5a1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,9 @@ members = [ "h3", "h3-quinn", + "h3-webtransport", # Internal "examples", ] +resolver = "2" diff --git a/changelog-h3-quinn.md b/changelog-h3-quinn.md index 61c19dcd..1a6a3c5c 100644 --- a/changelog-h3-quinn.md +++ b/changelog-h3-quinn.md @@ -1,3 +1,27 @@ +### v0.0.6 (2024-05-20) +* use quinn 0.11 + +### v0.0.5 (2024-01-24) + +* Update to `http` v1. 
+ +### v0.0.4 (2023-10-23) + +* Fix `cargo doc` warning ([3ef7c1a](https://github.com/hyperium/h3/commit/3ef7c1a37b635e8446322d8f8d3a68580a208ad8)) +* Initial WebTransport support ([22da938](https://github.com/hyperium/h3/commit/22da9387f19d724852b3bf1dfd7e66f0fd45cb81)) + + +### v0.0.3 (2023-05-16) + +* Update dependencies (quinn 0.10 and rustls 0.21) ([fabf614](https://github.com/hyperium/h3/commit/fabf6149c9ee57d7803ea5fb4426f895cbb5d244)) + +### v0.0.2 (2023-04-11) + +#### Bug Fixes + +* support quinn 0.9 ([49301f1](https://github.com/hyperium/h3/commit/49301f18e15d3acffc2a8d8bea1a8038c5f3fe6d)) + + ### v0.0.1 (2023-03-09) -initial release \ No newline at end of file +initial release diff --git a/changelog-h3.md b/changelog-h3.md index 61c19dcd..34febf67 100644 --- a/changelog-h3.md +++ b/changelog-h3.md @@ -1,3 +1,33 @@ +### v0.0.5 (2024-05-20) +* add `poll_recv_data()` for server +* use 2021 edition +* some cleanups + +### v0.0.4 (2024-01-24) + +* Update to `http` v1.0 +* Fix `try_recv` potentially hanging +* Fix prefixed integers on 32bit targets + +### v0.0.3 (2023-10-23) + +* Split out a `Settings` struct from `Config` ([a57ed22](https://github.com/hyperium/h3/commit/a57ed224ac5d17a635eb71eb6f83c1196f581a51)) +* Add a test-only send_settings config option ([3991dca](https://github.com/hyperium/h3/commit/3991dcaf3801595e49d0bb7fb1649b4cf50292b7)) +* Expose setting to disable grease ([dccb3cd](https://github.com/hyperium/h3/commit/dccb3cdae9d5a9d720fae5f774b53f0bd8a16019)) +* bugfix: Actually encode extensions in header ([a38b194](https://github.com/hyperium/h3/commit/a38b194a2f00dc0b2b60564c299093204d349d7e)) +* Initial support for RFC 9298 "Proxying UDP in HTTP" ([5a87580](https://github.com/hyperium/h3/commit/5a87580bd060b6a7d4dc625e990526d6998fda5c)) +* Bump H3_DATAGRAM setting ID according to RFC9297 ([58c8e5c](https://github.com/hyperium/h3/commit/58c8e5cecb2b0c367d738989fe9c505936bc5ff3)) +* Fix `cargo doc` warning ([3ef7c1a](https://github.com/hyperium/h3/commit/3ef7c1a37b635e8446322d8f8d3a68580a208ad8)) +* Initial WebTransport support (in h3 is just some necessary code to support a WebTransport crate which contains most of the WebTransport implementation) ([22da938](https://github.com/hyperium/h3/commit/22da9387f19d724852b3bf1dfd7e66f0fd45cb81)) + + +### v0.0.2 (2023-04-11) + +#### Bug Fixes + +* distinguish push and stream ids ([da29aea](https://github.com/hyperium/h3/commit/da29aea305d61146664189346b3718458cb9f4d6)) + + ### v0.0.1 (2023-03-09) -initial release \ No newline at end of file +initial release diff --git a/ci/compliance/specs/rfc9114/todo/4.2.toml b/ci/compliance/specs/rfc9114/todo/4.2.toml index d6b4eaa9..a37908c8 100644 --- a/ci/compliance/specs/rfc9114/todo/4.2.toml +++ b/ci/compliance/specs/rfc9114/todo/4.2.toml @@ -1,12 +1,5 @@ target = "https://www.rfc-editor.org/rfc/rfc9114#section-4.2" -[[TODO]] -quote = ''' -A request or -response containing uppercase characters in field names MUST be -treated as malformed. -''' - [[TODO]] quote = ''' An endpoint MUST NOT generate diff --git a/examples/Cargo.toml b/examples/Cargo.toml index bdbad438..6cdf5feb 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -2,24 +2,42 @@ name = "examples" version = "0.0.0" publish = false -edition = "2018" +edition = "2021" # If you copy one of the examples into a new project, you should be using # [dependencies] instead. 
[dev-dependencies] +anyhow = "1.0" bytes = "1" futures = "0.3" h3 = { path = "../h3" } h3-quinn = { path = "../h3-quinn" } -http = "0.2" -quinn = { version = "0.8", default-features = false, features = ["tls-rustls", "ring"] } -rcgen = { version = "0.9" } -rustls = { version = "0.20", features = ["dangerous_configuration"] } -rustls-native-certs = "0.6" +h3-webtransport = { path = "../h3-webtransport" } +http = "1" +quinn = { version = "0.11", default-features = false, features = [ + "runtime-tokio", + "rustls", + "ring", +] } +rcgen = { version = "0.13" } +rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std"] } +rustls-native-certs = "0.7" structopt = "0.3" -tokio = { version = "1.6", features = ["full"] } -tracing = "0.1.10" -tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter", "time", "tracing-log"] } +tokio = { version = "1.27", features = ["full"] } +tracing = "0.1.37" +tracing-subscriber = { version = "0.3", default-features = false, features = [ + "fmt", + "ansi", + "env-filter", + "time", + "tracing-log", +] } +octets = "0.3.0" + +tracing-tree = { version = "0.3" } + +[features] +tree = [] [[example]] name = "client" @@ -28,3 +46,7 @@ path = "client.rs" [[example]] name = "server" path = "server.rs" + +[[example]] +name = "webtransport_server" +path = "webtransport_server.rs" diff --git a/examples/client.rs b/examples/client.rs index 3f8d692b..f389e7f1 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,6 +1,7 @@ use std::{path::PathBuf, sync::Arc}; use futures::future; +use rustls::pki_types::CertificateDer; use structopt::StructOpt; use tokio::io::AsyncWriteExt; use tracing::{error, info}; @@ -64,7 +65,7 @@ async fn main() -> Result<(), Box> { match rustls_native_certs::load_native_certs() { Ok(certs) => { for cert in certs { - if let Err(e) = roots.add(&rustls::Certificate(cert.0)) { + if let Err(e) = roots.add(cert) { error!("failed to parse trust anchor: {}", e); } } @@ -76,14 +77,11 @@ async fn main() -> Result<(), Box> { // load certificate of CA who issues the server certificate // NOTE that this should be used for dev only - if let Err(e) = roots.add(&rustls::Certificate(std::fs::read(opt.ca)?)) { + if let Err(e) = roots.add(CertificateDer::from(std::fs::read(opt.ca)?)) { error!("failed to parse trust anchor: {}", e); } let mut tls_config = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13])? 
.with_root_certificates(roots) .with_no_client_auth(); @@ -99,7 +97,9 @@ async fn main() -> Result<(), Box> { let mut client_endpoint = h3_quinn::quinn::Endpoint::client("[::]:0".parse().unwrap())?; - let client_config = quinn::ClientConfig::new(Arc::new(tls_config)); + let client_config = quinn::ClientConfig::new(Arc::new( + quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)?, + )); client_endpoint.set_default_client_config(client_config); let conn = client_endpoint.connect(addr, auth.host())?.await?; diff --git a/examples/launch_chrome.sh b/examples/launch_chrome.sh new file mode 100755 index 00000000..97acd7ee --- /dev/null +++ b/examples/launch_chrome.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -e + +SPKI=`openssl x509 -inform der -in localhost.crt -pubkey -noout | openssl pkey -pubin -outform der | openssl dgst -sha256 -binary | openssl enc -base64` + +echo "Got cert key $SPKI" + +echo "Opening google chrome" + +case `uname` in + (*Linux*) google-chrome --origin-to-force-quic-on=127.0.0.1:4433 --ignore-certificate-errors-spki-list=$SPKI --enable-logging --v=1 ;; + (*Darwin*) open -a "Google Chrome" --args --origin-to-force-quic-on=127.0.0.1:4433 --ignore-certificate-errors-spki-list=$SPKI --enable-logging --v=1 ;; +esac + +## Logs are stored to ~/Library/Application Support/Google/Chrome/chrome_debug.log diff --git a/examples/server.rs b/examples/server.rs index efb840ab..6a917ca1 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,15 +1,14 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use bytes::{Bytes, BytesMut}; -use futures::StreamExt; use http::{Request, StatusCode}; -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use structopt::StructOpt; use tokio::{fs::File, io::AsyncReadExt}; use tracing::{error, info, trace_span}; use h3::{error::ErrorLevel, quic::BidiStream, server::RequestStream}; -use h3_quinn::quinn; +use h3_quinn::quinn::{self, crypto::rustls::QuicServerConfig}; #[derive(StructOpt, Debug)] #[structopt(name = "server")] @@ -85,28 +84,25 @@ async fn main() -> Result<(), Box> { // create quinn server endpoint and bind UDP socket // both cert and key must be DER-encoded - let cert = Certificate(std::fs::read(cert)?); - let key = PrivateKey(std::fs::read(key)?); + let cert = CertificateDer::from(std::fs::read(cert)?); + let key = PrivateKeyDer::try_from(std::fs::read(key)?)?; let mut tls_config = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() .with_no_client_auth() .with_single_cert(vec![cert], key)?; tls_config.max_early_data_size = u32::MAX; tls_config.alpn_protocols = vec![ALPN.into()]; - let server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_config)); - let (endpoint, mut incoming) = quinn::Endpoint::server(server_config, opt.listen)?; + let server_config = + quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?)); + let endpoint = quinn::Endpoint::server(server_config, opt.listen)?; info!("listening on {}", opt.listen); // handle incoming connections and requests - while let Some(new_conn) = incoming.next().await { + while let Some(new_conn) = endpoint.accept().await { trace_span!("New connection being attempted"); let root = root.clone(); diff --git a/examples/webtransport_server.rs b/examples/webtransport_server.rs new file mode 100644 index 00000000..9ecc7d15 --- /dev/null +++ b/examples/webtransport_server.rs @@ -0,0 +1,338 @@ +use 
anyhow::{Context, Result}; +use bytes::{BufMut, Bytes, BytesMut}; +use h3::{ + error::ErrorLevel, + ext::Protocol, + quic::{self, RecvDatagramExt, SendDatagramExt, SendStreamUnframed}, + server::Connection, +}; +use h3_quinn::quinn::{self, crypto::rustls::QuicServerConfig}; +use h3_webtransport::{ + server::{self, WebTransportSession}, + stream, +}; +use http::Method; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use structopt::StructOpt; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::pin; +use tracing::{error, info, trace_span}; + +#[derive(StructOpt, Debug)] +#[structopt(name = "server")] +struct Opt { + #[structopt( + short, + long, + default_value = "127.0.0.1:4433", + help = "What address:port to listen for new connections" + )] + pub listen: SocketAddr, + + #[structopt(flatten)] + pub certs: Certs, +} + +#[derive(StructOpt, Debug)] +pub struct Certs { + #[structopt( + long, + short, + default_value = "examples/localhost.crt", + help = "Certificate for TLS. If present, `--key` is mandatory." + )] + pub cert: PathBuf, + + #[structopt( + long, + short, + default_value = "examples/localhost.key", + help = "Private key for the certificate." + )] + pub key: PathBuf, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // 0. Setup tracing + #[cfg(not(feature = "tree"))] + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL) + .with_writer(std::io::stderr) + .init(); + + #[cfg(feature = "tree")] + use tracing_subscriber::prelude::*; + #[cfg(feature = "tree")] + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::from_default_env()) + .with(tracing_tree::HierarchicalLayer::new(4).with_bracketed_fields(true)) + .init(); + + // process cli arguments + + let opt = Opt::from_args(); + + tracing::info!("Opt: {opt:#?}"); + let Certs { cert, key } = opt.certs; + + // create quinn server endpoint and bind UDP socket + + // both cert and key must be DER-encoded + let cert = CertificateDer::from(std::fs::read(cert)?); + let key = PrivateKeyDer::try_from(std::fs::read(key)?)?; + + let mut tls_config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert], key)?; + + tls_config.max_early_data_size = u32::MAX; + let alpn: Vec> = vec![ + b"h3".to_vec(), + b"h3-32".to_vec(), + b"h3-31".to_vec(), + b"h3-30".to_vec(), + b"h3-29".to_vec(), + ]; + tls_config.alpn_protocols = alpn; + + let mut server_config = + quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?)); + let mut transport_config = quinn::TransportConfig::default(); + transport_config.keep_alive_interval(Some(Duration::from_secs(2))); + server_config.transport = Arc::new(transport_config); + let endpoint = quinn::Endpoint::server(server_config, opt.listen)?; + + info!("listening on {}", opt.listen); + + // 2. 
Accept new quic connections and spawn a new task to handle them + while let Some(new_conn) = endpoint.accept().await { + trace_span!("New connection being attempted"); + + tokio::spawn(async move { + match new_conn.await { + Ok(conn) => { + info!("new http3 established"); + let h3_conn = h3::server::builder() + .enable_webtransport(true) + .enable_connect(true) + .enable_datagram(true) + .max_webtransport_sessions(1) + .send_grease(true) + .build(h3_quinn::Connection::new(conn)) + .await + .unwrap(); + + // tracing::info!("Establishing WebTransport session"); + // // 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3. + // // if this is a webtransport session, then h3 needs to stop handing the datagrams, bidirectional streams, and unidirectional streams and give them + // // to the webtransport session. + + tokio::spawn(async move { + if let Err(err) = handle_connection(h3_conn).await { + tracing::error!("Failed to handle connection: {err:?}"); + } + }); + // let mut session: WebTransportSession<_, Bytes> = + // WebTransportSession::accept(h3_conn).await.unwrap(); + // tracing::info!("Finished establishing webtransport session"); + // // 4. Get datagrams, bidirectional streams, and unidirectional streams and wait for client requests here. + // // h3_conn needs to handover the datagrams, bidirectional streams, and unidirectional streams to the webtransport session. + // let result = handle.await; + } + Err(err) => { + error!("accepting connection failed: {:?}", err); + } + } + }); + } + + // shut down gracefully + // wait for connections to be closed before exiting + endpoint.wait_idle().await; + + Ok(()) +} + +async fn handle_connection(mut conn: Connection) -> Result<()> { + // 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3. + // if this is a webtransport session, then h3 needs to stop handing the datagrams, bidirectional streams, and unidirectional streams and give them + // to the webtransport session. + + loop { + match conn.accept().await { + Ok(Some((req, stream))) => { + info!("new request: {:#?}", req); + + let ext = req.extensions(); + match req.method() { + &Method::CONNECT if ext.get::() == Some(&Protocol::WEB_TRANSPORT) => { + tracing::info!("Peer wants to initiate a webtransport session"); + + tracing::info!("Handing over connection to WebTransport"); + let session = WebTransportSession::accept(req, stream, conn).await?; + tracing::info!("Established webtransport session"); + // 4. Get datagrams, bidirectional streams, and unidirectional streams and wait for client requests here. + // h3_conn needs to handover the datagrams, bidirectional streams, and unidirectional streams to the webtransport session. + handle_session_and_echo_all_inbound_messages(session).await?; + + return Ok(()); + } + _ => { + tracing::info!(?req, "Received request"); + } + } + } + + // indicating no more streams to be received + Ok(None) => { + break; + } + + Err(err) => { + error!("Error on accept {}", err); + match err.get_error_level() { + ErrorLevel::ConnectionError => break, + ErrorLevel::StreamError => continue, + } + } + } + } + Ok(()) +} + +macro_rules! 
log_result { + ($expr:expr) => { + if let Err(err) = $expr { + tracing::error!("{err:?}"); + } + }; +} + +async fn echo_stream(send: T, recv: R) -> anyhow::Result<()> +where + T: AsyncWrite, + R: AsyncRead, +{ + pin!(send); + pin!(recv); + + tracing::info!("Got stream"); + let mut buf = Vec::new(); + recv.read_to_end(&mut buf).await?; + + let message = Bytes::from(buf); + + send_chunked(send, message).await?; + + Ok(()) +} + +// Used to test that all chunks arrive properly as it is easy to write an impl which only reads and +// writes the first chunk. +async fn send_chunked(mut send: impl AsyncWrite + Unpin, data: Bytes) -> anyhow::Result<()> { + for chunk in data.chunks(4) { + tokio::time::sleep(Duration::from_millis(100)).await; + tracing::info!("Sending {chunk:?}"); + send.write_all(chunk).await?; + } + + Ok(()) +} + +async fn open_bidi_test(mut stream: S) -> anyhow::Result<()> +where + S: Unpin + AsyncRead + AsyncWrite, +{ + tracing::info!("Opening bidirectional stream"); + + stream + .write_all(b"Hello from a server initiated bidi stream") + .await + .context("Failed to respond")?; + + let mut resp = Vec::new(); + stream.shutdown().await?; + stream.read_to_end(&mut resp).await?; + + tracing::info!("Got response from client: {resp:?}"); + + Ok(()) +} + +/// This method will echo all inbound datagrams, unidirectional and bidirectional streams. +#[tracing::instrument(level = "info", skip(session))] +async fn handle_session_and_echo_all_inbound_messages( + session: WebTransportSession, +) -> anyhow::Result<()> +where + // Use trait bounds to ensure we only happen to use implementation that are only for the quinn + // backend. + C: 'static + + Send + + h3::quic::Connection + + RecvDatagramExt + + SendDatagramExt, + >::Error: + 'static + std::error::Error + Send + Sync + Into, + ::Error: + 'static + std::error::Error + Send + Sync + Into, + stream::BidiStream: + quic::BidiStream + Unpin + AsyncWrite + AsyncRead, + as quic::BidiStream>::SendStream: + Unpin + AsyncWrite + Send + Sync, + as quic::BidiStream>::RecvStream: + Unpin + AsyncRead + Send + Sync, + C::SendStream: Send + Unpin, + C::RecvStream: Send + Unpin, + C::BidiStream: Send + Unpin, + stream::SendStream: AsyncWrite, + C::BidiStream: SendStreamUnframed, + C::SendStream: SendStreamUnframed, +{ + let session_id = session.session_id(); + + // This will open a bidirectional stream and send a message to the client right after connecting! + let stream = session.open_bi(session_id).await?; + + tokio::spawn(async move { log_result!(open_bidi_test(stream).await) }); + + loop { + tokio::select! { + datagram = session.accept_datagram() => { + let datagram = datagram?; + if let Some((_, datagram)) = datagram { + tracing::info!("Responding with {datagram:?}"); + // Put something before to make sure encoding and decoding works and don't just + // pass through + let mut resp = BytesMut::from(&b"Response: "[..]); + resp.put(datagram); + + session.send_datagram(resp.freeze())?; + tracing::info!("Finished sending datagram"); + } + } + uni_stream = session.accept_uni() => { + let (id, stream) = uni_stream?.unwrap(); + + let send = session.open_uni(id).await?; + tokio::spawn( async move { log_result!(echo_stream(send, stream).await); }); + } + stream = session.accept_bi() => { + if let Some(server::AcceptedBi::BidiStream(_, stream)) = stream? 
{ + let (send, recv) = quic::BidiStream::split(stream); + tokio::spawn( async move { log_result!(echo_stream(send, recv).await); }); + } + } + else => { + break + } + } + } + + tracing::info!("Finished handling session"); + + Ok(()) +} diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index fab58bbc..54e5b92c 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -3,14 +3,14 @@ name = "h3-fuzz" version = "0.0.0" authors = ["Automatically generated"] publish = false -edition = "2018" +edition = "2021" [package.metadata] cargo-fuzz = true [dependencies] -libfuzzer-sys = "0.3" -bytes = "0.5.6" +libfuzzer-sys = "0.4" +bytes = "1.5" # Prevent this from interfering with workspaces [workspace] diff --git a/h3-quinn/Cargo.toml b/h3-quinn/Cargo.toml index 8474bf9c..aade3fdb 100644 --- a/h3-quinn/Cargo.toml +++ b/h3-quinn/Cargo.toml @@ -1,19 +1,23 @@ [package] name = "h3-quinn" -version = "0.0.1" +version = "0.0.6" +rust-version = "1.66" authors = ["Jean-Christophe BEGUE "] -edition = "2018" +edition = "2021" documentation = "https://docs.rs/h3-quinn" repository = "https://github.com/hyperium/h3" readme = "../README.md" description = "QUIC transport implementation based on Quinn." -keywords = ["http3","quic","h3"] +keywords = ["http3", "quic", "h3"] categories = ["network-programming", "web-programming"] license = "MIT" [dependencies] -h3 = { version = "0.0.1", path = "../h3" } +h3 = { version = "0.0.5", path = "../h3" } bytes = "1" -futures-util = { version = "0.3", default-features = false, features = ["io"] } -quinn = { version = "0.8.0", default-features = false } -quinn-proto = { version = "0.8.0", default-features = false } +quinn = { version = "0.11", default-features = false, features = [ + "futures-io", +] } +tokio-util = { version = "0.7.9" } +futures = { version = "0.3.28" } +tokio = { version = "1", features = ["io-util"], default-features = false } diff --git a/h3-quinn/LICENSE b/h3-quinn/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/h3-quinn/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/h3-quinn/src/lib.rs b/h3-quinn/src/lib.rs index eb44cb89..7819b383 100644 --- a/h3-quinn/src/lib.rs +++ b/h3-quinn/src/lib.rs @@ -6,51 +6,56 @@ use std::{ convert::TryInto, fmt::{self, Display}, + future::Future, pin::Pin, sync::Arc, task::{self, Poll}, }; -use bytes::{Buf, Bytes}; -use futures_util::future::FutureExt as _; -use futures_util::io::AsyncWrite as _; -use futures_util::ready; -use futures_util::stream::StreamExt as _; +use bytes::{Buf, Bytes, BytesMut}; -pub use quinn::{ - self, crypto::Session, Endpoint, IncomingBiStreams, IncomingUniStreams, NewConnection, OpenBi, - OpenUni, VarInt, WriteError, +use futures::{ + ready, + stream::{self, BoxStream}, + StreamExt, }; +pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError}; +use quinn::{ApplicationClose, ClosedStream, ReadDatagram}; -use h3::quic::{self, Error, StreamId, WriteBuf}; +use h3::{ + ext::Datagram, + quic::{self, Error, StreamId, WriteBuf}, +}; +use tokio_util::sync::ReusableBoxFuture; /// A QUIC connection backed by Quinn /// /// Implements a [`quic::Connection`] backed by a [`quinn::Connection`]. 
pub struct Connection { conn: quinn::Connection, - incoming_bi: IncomingBiStreams, - opening_bi: Option, - incoming_uni: IncomingUniStreams, - opening_uni: Option, + incoming_bi: BoxStream<'static, as Future>::Output>, + opening_bi: Option as Future>::Output>>, + incoming_uni: BoxStream<'static, as Future>::Output>, + opening_uni: Option as Future>::Output>>, + datagrams: BoxStream<'static, as Future>::Output>, } impl Connection { - /// Create a [`Connection`] from a [`quinn::NewConnection`] - pub fn new(new_conn: NewConnection) -> Self { - let NewConnection { - uni_streams, - bi_streams, - connection, - .. - } = new_conn; - + /// Create a [`Connection`] from a [`quinn::Connection`] + pub fn new(conn: quinn::Connection) -> Self { Self { - conn: connection, - incoming_bi: bi_streams, + conn: conn.clone(), + incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { + Some((conn.accept_bi().await, conn)) + })), opening_bi: None, - incoming_uni: uni_streams, + incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { + Some((conn.accept_uni().await, conn)) + })), opening_uni: None, + datagrams: Box::pin(stream::unfold(conn, |conn| async { + Some((conn.read_datagram().await, conn)) + })), } } } @@ -76,10 +81,9 @@ impl Error for ConnectionError { fn err_code(&self) -> Option { match self.0 { - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }) => Some(error_code.into_inner()), + quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) => { + Some(error_code.into_inner()) + } _ => None, } } @@ -91,6 +95,58 @@ impl From for ConnectionError { } } +/// Types of errors when sending a datagram. +#[derive(Debug)] +pub enum SendDatagramError { + /// Datagrams are not supported by the peer + UnsupportedByPeer, + /// Datagrams are locally disabled + Disabled, + /// The datagram was too large to be sent. 
+ TooLarge, + /// Network error + ConnectionLost(Box), +} + +impl fmt::Display for SendDatagramError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), + SendDatagramError::Disabled => write!(f, "datagram support disabled"), + SendDatagramError::TooLarge => write!(f, "datagram too large"), + SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), + } + } +} + +impl std::error::Error for SendDatagramError {} + +impl Error for SendDatagramError { + fn is_timeout(&self) -> bool { + false + } + + fn err_code(&self) -> Option { + match self { + Self::ConnectionLost(err) => err.err_code(), + _ => None, + } + } +} + +impl From for SendDatagramError { + fn from(value: quinn::SendDatagramError) -> Self { + match value { + quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, + quinn::SendDatagramError::Disabled => Self::Disabled, + quinn::SendDatagramError::TooLarge => Self::TooLarge, + quinn::SendDatagramError::ConnectionLost(err) => { + Self::ConnectionLost(ConnectionError::from(err).into()) + } + } + } +} + impl quic::Connection for Connection where B: Buf, @@ -103,7 +159,7 @@ where &mut self, cx: &mut task::Context<'_>, ) -> Poll, Self::AcceptError>> { - let (send, recv) = match ready!(self.incoming_bi.next().poll_unpin(cx)) { + let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { Some(x) => x?, None => return Poll::Ready(Ok(None)), }; @@ -146,10 +202,13 @@ where cx: &mut task::Context<'_>, ) -> Poll> { if self.opening_bi.is_none() { - self.opening_bi = Some(self.conn.open_bi()); + self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.clone().open_bi().await, conn)) + }))); } - let (send, recv) = ready!(self.opening_bi.as_mut().unwrap().poll_unpin(cx))?; + let (send, recv) = + ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; Poll::Ready(Ok(Self::BidiStream { send: Self::SendStream::new(send), recv: RecvStream::new(recv), @@ -161,10 +220,12 @@ where cx: &mut task::Context<'_>, ) -> Poll> { if self.opening_uni.is_none() { - self.opening_uni = Some(self.conn.open_uni()); + self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.open_uni().await, conn)) + }))); } - let send = ready!(self.opening_uni.as_mut().unwrap().poll_unpin(cx))?; + let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; Poll::Ready(Ok(Self::SendStream::new(send))) } @@ -176,14 +237,48 @@ where } } +impl quic::SendDatagramExt for Connection +where + B: Buf, +{ + type Error = SendDatagramError; + + fn send_datagram(&mut self, data: Datagram) -> Result<(), SendDatagramError> { + // TODO investigate static buffer from known max datagram size + let mut buf = BytesMut::new(); + data.encode(&mut buf); + self.conn.send_datagram(buf.freeze())?; + + Ok(()) + } +} + +impl quic::RecvDatagramExt for Connection { + type Buf = Bytes; + + type Error = ConnectionError; + + #[inline] + fn poll_accept_datagram( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + match ready!(self.datagrams.poll_next_unpin(cx)) { + Some(Ok(x)) => Poll::Ready(Ok(Some(x))), + Some(Err(e)) => Poll::Ready(Err(e.into())), + None => Poll::Ready(Ok(None)), + } + } +} + /// Stream opener backed by a Quinn connection /// /// Implements [`quic::OpenStreams`] using [`quinn::Connection`], /// [`quinn::OpenBi`], [`quinn::OpenUni`]. 
pub struct OpenStreams { conn: quinn::Connection, - opening_bi: Option, - opening_uni: Option, + opening_bi: Option as Future>::Output>>, + opening_uni: Option as Future>::Output>>, } impl quic::OpenStreams for OpenStreams @@ -199,10 +294,13 @@ where cx: &mut task::Context<'_>, ) -> Poll> { if self.opening_bi.is_none() { - self.opening_bi = Some(self.conn.open_bi()); + self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.open_bi().await, conn)) + }))); } - let (send, recv) = ready!(self.opening_bi.as_mut().unwrap().poll_unpin(cx))?; + let (send, recv) = + ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; Poll::Ready(Ok(Self::BidiStream { send: Self::SendStream::new(send), recv: RecvStream::new(recv), @@ -214,10 +312,12 @@ where cx: &mut task::Context<'_>, ) -> Poll> { if self.opening_uni.is_none() { - self.opening_uni = Some(self.conn.open_uni()); + self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.open_uni().await, conn)) + }))); } - let send = ready!(self.opening_uni.as_mut().unwrap().poll_unpin(cx))?; + let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; Poll::Ready(Ok(Self::SendStream::new(send))) } @@ -263,10 +363,7 @@ where } } -impl quic::RecvStream for BidiStream -where - B: Buf, -{ +impl quic::RecvStream for BidiStream { type Buf = Bytes; type Error = ReadError; @@ -280,6 +377,10 @@ where fn stop_sending(&mut self, error_code: u64) { self.recv.stop_sending(error_code) } + + fn recv_id(&self) -> StreamId { + self.recv.recv_id() + } } impl quic::SendStream for BidiStream @@ -304,8 +405,20 @@ where self.send.send_data(data) } - fn id(&self) -> StreamId { - self.send.id() + fn send_id(&self) -> StreamId { + self.send.send_id() + } +} +impl quic::SendStreamUnframed for BidiStream +where + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.send.poll_send(cx, buf) } } @@ -313,12 +426,25 @@ where /// /// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`]. pub struct RecvStream { - stream: quinn::RecvStream, + stream: Option, + read_chunk_fut: ReadChunkFuture, } +type ReadChunkFuture = ReusableBoxFuture< + 'static, + ( + quinn::RecvStream, + Result, quinn::ReadError>, + ), +>; + impl RecvStream { fn new(stream: quinn::RecvStream) -> Self { - Self { stream } + Self { + stream: Some(stream), + // Should only allocate once the first time it's used + read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), + } } } @@ -330,17 +456,34 @@ impl quic::RecvStream for RecvStream { &mut self, cx: &mut task::Context<'_>, ) -> Poll, Self::Error>> { - Poll::Ready(Ok(ready!(self - .stream - .read_chunk(usize::MAX, true) - .poll_unpin(cx))? 
- .map(|c| (c.bytes)))) + if let Some(mut stream) = self.stream.take() { + self.read_chunk_fut.set(async move { + let chunk = stream.read_chunk(usize::MAX, true).await; + (stream, chunk) + }) + }; + + let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); + self.stream = Some(stream); + Poll::Ready(Ok(chunk?.map(|c| c.bytes))) } fn stop_sending(&mut self, error_code: u64) { - let _ = self - .stream - .stop(VarInt::from_u64(error_code).expect("invalid error_code")); + self.stream + .as_mut() + .unwrap() + .stop(VarInt::from_u64(error_code).expect("invalid error_code")) + .ok(); + } + + fn recv_id(&self) -> StreamId { + self.stream + .as_ref() + .unwrap() + .id() + .0 + .try_into() + .expect("invalid stream id") } } @@ -350,7 +493,17 @@ impl quic::RecvStream for RecvStream { #[derive(Debug)] pub struct ReadError(quinn::ReadError); -impl std::error::Error for ReadError {} +impl From for std::io::Error { + fn from(value: ReadError) -> Self { + value.0.into() + } +} + +impl std::error::Error for ReadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0.source() + } +} impl fmt::Display for ReadError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -381,7 +534,7 @@ impl Error for ReadError { fn err_code(&self) -> Option { match self.0 { quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( - quinn_proto::ApplicationClose { error_code, .. }, + ApplicationClose { error_code, .. }, )) => Some(error_code.into_inner()), quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), _ => None, @@ -393,18 +546,23 @@ impl Error for ReadError { /// /// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`]. pub struct SendStream { - stream: quinn::SendStream, + stream: Option, writing: Option>, + write_fut: WriteFuture, } +type WriteFuture = + ReusableBoxFuture<'static, (quinn::SendStream, Result)>; + impl SendStream where B: Buf, { fn new(stream: quinn::SendStream) -> SendStream { Self { - stream, + stream: Some(stream), writing: None, + write_fut: ReusableBoxFuture::new(async { unreachable!() }), } } } @@ -418,24 +576,20 @@ where fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { if let Some(ref mut data) = self.writing { while data.has_remaining() { - match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) { + if let Some(mut stream) = self.stream.take() { + let chunk = data.chunk().to_owned(); // FIXME - avoid copy + self.write_fut.set(async move { + let ret = stream.write(&chunk).await; + (stream, ret) + }); + } + + let (stream, res) = ready!(self.write_fut.poll(cx)); + self.stream = Some(stream); + match res { Ok(cnt) => data.advance(cnt), Err(err) => { - // We are forced to use AsyncWrite for now because we cannot store - // the result of a call to: - // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S>. - // - // This is why we have to unpack the error from io::Error below. This should not - // panic as long as quinn's AsyncWrite impl doesn't change. 
- return Poll::Ready(Err(SendStreamError::Write( - err.into_inner() - .expect("write stream returned an empty error") - .downcast_ref::() - .expect( - "write stream returned an error which type is not WriteError", - ) - .clone(), - ))); + return Poll::Ready(Err(SendStreamError::Write(err))); } } } @@ -444,13 +598,15 @@ where Poll::Ready(Ok(())) } - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.stream.finish().poll_unpin(cx).map_err(Into::into) + fn poll_finish(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + Poll::Ready(self.stream.as_mut().unwrap().finish().map_err(|e| e.into())) } fn reset(&mut self, reset_code: u64) { let _ = self .stream + .as_mut() + .unwrap() .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); } @@ -462,8 +618,56 @@ where Ok(()) } - fn id(&self) -> StreamId { - self.stream.id().0.try_into().expect("invalid stream id") + fn send_id(&self) -> StreamId { + self.stream + .as_ref() + .unwrap() + .id() + .0 + .try_into() + .expect("invalid stream id") + } +} + +impl quic::SendStreamUnframed for SendStream +where + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll> { + if self.writing.is_some() { + // This signifies a bug in implementation + panic!("poll_send called while send stream is not ready") + } + + let s = Pin::new(self.stream.as_mut().unwrap()); + + let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); + match res { + Ok(written) => { + buf.advance(written); + Poll::Ready(Ok(written)) + } + Err(err) => { + // We are forced to use AsyncWrite for now because we cannot store + // the result of a call to: + // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result. + // + // This is why we have to unpack the error from io::Error instead of having it + // returned directly. This should not panic as long as quinn's AsyncWrite impl + // doesn't change. + let err = err + .into_inner() + .expect("write stream returned an empty error") + .downcast::() + .expect("write stream returned an error which type is not WriteError"); + + Poll::Ready(Err(SendStreamError::Write(*err))) + } + } } } @@ -477,6 +681,20 @@ pub enum SendStreamError { /// Error when the stream is not ready, because it is still sending /// data from a previous call NotReady, + /// Error when the stream is closed + StreamClosed(ClosedStream), +} + +impl From for std::io::Error { + fn from(value: SendStreamError) -> Self { + match value { + SendStreamError::Write(err) => err.into(), + SendStreamError::NotReady => { + std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") + } + SendStreamError::StreamClosed(err) => err.into(), + } + } } impl std::error::Error for SendStreamError {} @@ -493,6 +711,12 @@ impl From for SendStreamError { } } +impl From for SendStreamError { + fn from(value: ClosedStream) -> Self { + Self::StreamClosed(value) + } +} + impl Error for SendStreamError { fn is_timeout(&self) -> bool { matches!( @@ -507,10 +731,7 @@ impl Error for SendStreamError { match self { Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }), + quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. 
}), )) => Some(error_code.into_inner()), _ => None, } diff --git a/h3-webtransport/Cargo.toml b/h3-webtransport/Cargo.toml new file mode 100644 index 00000000..a502cd0e --- /dev/null +++ b/h3-webtransport/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "h3-webtransport" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1" +futures-util = { version = "0.3", default-features = false } +http = "1" +pin-project-lite = { version = "0.2", default_features = false } +tracing = "0.1.37" +tokio = { version = "1.28", default_features = false } + +[dependencies.h3] +version = "0.0.5" +path = "../h3" +features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"] diff --git a/h3-webtransport/src/lib.rs b/h3-webtransport/src/lib.rs new file mode 100644 index 00000000..85762d45 --- /dev/null +++ b/h3-webtransport/src/lib.rs @@ -0,0 +1,13 @@ +//! Provides the client and server support for WebTransport sessions. +//! +//! # Relevant Links +//! WebTransport: +//! WebTransport over HTTP/3: +#![deny(missing_docs)] + +/// Server side WebTransport session support +pub mod server; +/// Webtransport stream types +pub mod stream; + +pub use h3::webtransport::SessionId; diff --git a/h3-webtransport/src/server.rs b/h3-webtransport/src/server.rs new file mode 100644 index 00000000..05970b61 --- /dev/null +++ b/h3-webtransport/src/server.rs @@ -0,0 +1,428 @@ +//! Provides the server side WebTransport session + +use std::{ + marker::PhantomData, + pin::Pin, + sync::Mutex, + task::{Context, Poll}, +}; + +use bytes::Buf; +use futures_util::{future::poll_fn, ready, Future}; +use h3::{ + connection::ConnectionState, + error::{Code, ErrorLevel}, + ext::{Datagram, Protocol}, + frame::FrameStream, + proto::frame::Frame, + quic::{self, OpenStreams, RecvDatagramExt, SendDatagramExt, WriteBuf}, + server::Connection, + server::RequestStream, + Error, +}; +use h3::{ + quic::SendStreamUnframed, + stream::{BidiStreamHeader, BufRecvStream, UniStreamHeader}, +}; +use http::{Method, Request, Response, StatusCode}; + +use h3::webtransport::SessionId; +use pin_project_lite::pin_project; + +use crate::stream::{BidiStream, RecvStream, SendStream}; + +/// WebTransport session driver. +/// +/// Maintains the session using the underlying HTTP/3 connection. +/// +/// Similar to [`h3::server::Connection`](https://docs.rs/h3/latest/h3/server/struct.Connection.html) it is generic over the QUIC implementation and Buffer. +pub struct WebTransportSession +where + C: quic::Connection, + B: Buf, +{ + // See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-2-3 + session_id: SessionId, + /// The underlying HTTP/3 connection + server_conn: Mutex>, + connect_stream: RequestStream, + opener: Mutex, +} + +impl WebTransportSession +where + C: quic::Connection, + B: Buf, +{ + /// Accepts a *CONNECT* request for establishing a WebTransport session. + /// + /// TODO: is the API or the user responsible for validating the CONNECT request? 
+ pub async fn accept( + request: Request<()>, + mut stream: RequestStream, + mut conn: Connection, + ) -> Result { + let shared = conn.shared_state().clone(); + { + let config = shared.write("Read WebTransport support").peer_config; + + if !config.enable_webtransport() { + return Err(conn.close( + Code::H3_SETTINGS_ERROR, + "webtransport is not supported by client", + )); + } + + if !config.enable_datagram() { + return Err(conn.close( + Code::H3_SETTINGS_ERROR, + "datagrams are not supported by client", + )); + } + } + + // The peer is responsible for validating our side of the webtransport support. + // + // However, it is still advantageous to show a log on the server as (attempting) to + // establish a WebTransportSession without the proper h3 config is usually a mistake. + if !conn.inner.config.settings.enable_webtransport() { + tracing::warn!("Server does not support webtransport"); + } + + if !conn.inner.config.settings.enable_datagram() { + tracing::warn!("Server does not support datagrams"); + } + + if !conn.inner.config.settings.enable_extended_connect() { + tracing::warn!("Server does not support CONNECT"); + } + + // Respond to the CONNECT request. + + //= https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.3 + let response = if validate_wt_connect(&request) { + Response::builder() + // This is the only header that chrome cares about. + .header("sec-webtransport-http3-draft", "draft02") + .status(StatusCode::OK) + .body(()) + .unwrap() + } else { + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(()) + .unwrap() + }; + + stream.send_response(response).await?; + + let session_id = stream.send_id().into(); + let conn_inner = &mut conn.inner.conn; + let opener = Mutex::new(conn_inner.opener()); + + Ok(Self { + session_id, + opener, + server_conn: Mutex::new(conn), + connect_stream: stream, + }) + } + + /// Receive a datagram from the client + pub fn accept_datagram(&self) -> ReadDatagram { + ReadDatagram { + conn: &self.server_conn, + _marker: PhantomData, + } + } + + /// Sends a datagram + /// + /// TODO: maybe make async. `quinn` does not require an async send + pub fn send_datagram(&self, data: B) -> Result<(), Error> + where + C: SendDatagramExt, + { + self.server_conn + .lock() + .unwrap() + .send_datagram(self.connect_stream.id(), data)?; + + Ok(()) + } + + /// Accept an incoming unidirectional stream from the client, it reads the stream until EOF. + pub fn accept_uni(&self) -> AcceptUni { + AcceptUni { + conn: &self.server_conn, + } + } + + /// Accepts an incoming bidirectional stream or request + pub async fn accept_bi(&self) -> Result>, Error> { + // Get the next stream + // Accept the incoming stream + let stream = poll_fn(|cx| { + let mut conn = self.server_conn.lock().unwrap(); + conn.poll_accept_request(cx) + }) + .await; + + let mut stream = match stream { + Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)), + Ok(None) => { + // FIXME: is proper HTTP GoAway shutdown required? + return Ok(None); + } + Err(err) => { + match err.kind() { + h3::error::Kind::Closed => return Ok(None), + h3::error::Kind::Application { + code, + reason, + level: ErrorLevel::ConnectionError, + .. + } => { + return Err(self.server_conn.lock().unwrap().close( + code, + reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), + )) + } + _ => return Err(err), + }; + } + }; + + // Read the first frame. 
+ // + // This will determine if it is a webtransport bi-stream or a request stream + let frame = poll_fn(|cx| stream.poll_next(cx)).await; + + match frame { + Ok(None) => Ok(None), + Ok(Some(Frame::WebTransportStream(session_id))) => { + // Take the stream out of the framed reader and split it in half like Paul Allen + let stream = stream.into_inner(); + + Ok(Some(AcceptedBi::BidiStream( + session_id, + BidiStream::new(stream), + ))) + } + // Make the underlying HTTP/3 connection handle the rest + frame => { + let req = { + let mut conn = self.server_conn.lock().unwrap(); + conn.accept_with_frame(stream, frame)? + }; + if let Some(req) = req { + let (req, resp) = req.resolve().await?; + Ok(Some(AcceptedBi::Request(req, resp))) + } else { + Ok(None) + } + } + } + } + + /// Open a new bidirectional stream + pub fn open_bi(&self, session_id: SessionId) -> OpenBi { + OpenBi { + opener: &self.opener, + stream: None, + session_id, + } + } + + /// Open a new unidirectional stream + pub fn open_uni(&self, session_id: SessionId) -> OpenUni { + OpenUni { + opener: &self.opener, + stream: None, + session_id, + } + } + + /// Returns the session id + pub fn session_id(&self) -> SessionId { + self.session_id + } +} + +/// Streams are opened, but the initial webtransport header has not been sent +type PendingStreams = ( + BidiStream<>::BidiStream, B>, + WriteBuf<&'static [u8]>, +); + +/// Streams are opened, but the initial webtransport header has not been sent +type PendingUniStreams = ( + SendStream<>::SendStream, B>, + WriteBuf<&'static [u8]>, +); + +pin_project! { + /// Future for opening a bidi stream + pub struct OpenBi<'a, C:quic::Connection, B:Buf> { + opener: &'a Mutex, + stream: Option>, + session_id: SessionId, + } +} + +impl<'a, B, C> Future for OpenBi<'a, C, B> +where + C: quic::Connection, + B: Buf, + C::BidiStream: SendStreamUnframed, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut p = self.project(); + loop { + match &mut p.stream { + Some((stream, buf)) => { + while buf.has_remaining() { + ready!(stream.poll_send(cx, buf))?; + } + + let (stream, _) = p.stream.take().unwrap(); + return Poll::Ready(Ok(stream)); + } + None => { + let mut opener = (*p.opener).lock().unwrap(); + // Open the stream first + let res = ready!(opener.poll_open_bidi(cx))?; + let stream = BidiStream::new(BufRecvStream::new(res)); + + let buf = WriteBuf::from(BidiStreamHeader::WebTransportBidi(*p.session_id)); + *p.stream = Some((stream, buf)); + } + } + } + } +} + +pin_project! 
{ + /// Opens a unidirectional stream + pub struct OpenUni<'a, C: quic::Connection, B:Buf> { + opener: &'a Mutex, + stream: Option>, + // Future for opening a uni stream + session_id: SessionId, + } +} + +impl<'a, C, B> Future for OpenUni<'a, C, B> +where + C: quic::Connection, + B: Buf, + C::SendStream: SendStreamUnframed, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut p = self.project(); + loop { + match &mut p.stream { + Some((send, buf)) => { + while buf.has_remaining() { + ready!(send.poll_send(cx, buf))?; + } + let (send, buf) = p.stream.take().unwrap(); + assert!(!buf.has_remaining()); + return Poll::Ready(Ok(send)); + } + None => { + let mut opener = (*p.opener).lock().unwrap(); + let send = ready!(opener.poll_open_send(cx))?; + let send = BufRecvStream::new(send); + let send = SendStream::new(send); + + let buf = WriteBuf::from(UniStreamHeader::WebTransportUni(*p.session_id)); + *p.stream = Some((send, buf)); + } + } + } + } +} + +/// An accepted incoming bidirectional stream. +/// +/// Since +pub enum AcceptedBi, B: Buf> { + /// An incoming bidirectional stream + BidiStream(SessionId, BidiStream), + /// An incoming HTTP/3 request, passed through a webtransport session. + /// + /// This makes it possible to respond to multiple CONNECT requests + Request(Request<()>, RequestStream), +} + +/// Future for [`Connection::read_datagram`] +pub struct ReadDatagram<'a, C, B> +where + C: quic::Connection, + B: Buf, +{ + conn: &'a Mutex>, + _marker: PhantomData, +} + +impl<'a, C, B> Future for ReadDatagram<'a, C, B> +where + C: quic::Connection + RecvDatagramExt, + B: Buf, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut conn = self.conn.lock().unwrap(); + match ready!(conn.inner.conn.poll_accept_datagram(cx))? { + Some(v) => { + let datagram = Datagram::decode(v)?; + Poll::Ready(Ok(Some(( + datagram.stream_id().into(), + datagram.into_payload(), + )))) + } + None => Poll::Ready(Ok(None)), + } + } +} + +/// Future for [`WebTransportSession::accept_uni`] +pub struct AcceptUni<'a, C, B> +where + C: quic::Connection, + B: Buf, +{ + conn: &'a Mutex>, +} + +impl<'a, C, B> Future for AcceptUni<'a, C, B> +where + C: quic::Connection, + B: Buf, +{ + type Output = Result)>, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut conn = self.conn.lock().unwrap(); + conn.inner.poll_accept_recv(cx)?; + + // Get the currently available streams + let streams = conn.inner.accepted_streams_mut(); + if let Some((id, stream)) = streams.wt_uni_streams.pop() { + return Poll::Ready(Ok(Some((id, RecvStream::new(stream))))); + } + + Poll::Pending + } +} + +fn validate_wt_connect(request: &Request<()>) -> bool { + let protocol = request.extensions().get::(); + matches!((request.method(), protocol), (&Method::CONNECT, Some(p)) if p == &Protocol::WEB_TRANSPORT) +} diff --git a/h3-webtransport/src/stream.rs b/h3-webtransport/src/stream.rs new file mode 100644 index 00000000..4f29f6f5 --- /dev/null +++ b/h3-webtransport/src/stream.rs @@ -0,0 +1,382 @@ +use std::task::Poll; + +use bytes::{Buf, Bytes}; +use h3::{quic, stream::BufRecvStream}; +use pin_project_lite::pin_project; +use tokio::io::ReadBuf; + +pin_project! 
{ + /// WebTransport receive stream + pub struct RecvStream { + #[pin] + stream: BufRecvStream, + } +} + +impl RecvStream { + #[allow(missing_docs)] + pub fn new(stream: BufRecvStream) -> Self { + Self { stream } + } +} + +impl quic::RecvStream for RecvStream +where + S: quic::RecvStream, + B: Buf, +{ + type Buf = Bytes; + + type Error = S::Error; + + fn poll_data( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>> { + self.stream.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stream.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.recv_id() + } +} + +impl futures_util::io::AsyncRead for RecvStream +where + BufRecvStream: futures_util::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +impl tokio::io::AsyncRead for RecvStream +where + BufRecvStream: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +pin_project! { + /// WebTransport send stream + pub struct SendStream { + #[pin] + stream: BufRecvStream, + } +} + +impl std::fmt::Debug for SendStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendStream") + .field("stream", &self.stream) + .finish() + } +} + +impl SendStream { + #[allow(missing_docs)] + pub(crate) fn new(stream: BufRecvStream) -> Self { + Self { stream } + } +} + +impl quic::SendStreamUnframed for SendStream +where + S: quic::SendStreamUnframed, + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) + } +} + +impl quic::SendStream for SendStream +where + S: quic::SendStream, + B: Buf, +{ + type Error = S::Error; + + fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.stream.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.stream.send_id() + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.stream.send_data(data) + } + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_ready(cx) + } +} + +impl futures_util::io::AsyncWrite for SendStream +where + BufRecvStream: futures_util::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_close(cx) + } +} + +impl tokio::io::AsyncWrite for SendStream +where + BufRecvStream: tokio::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: 
&mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_shutdown(cx) + } +} + +pin_project! { + /// Combined send and receive stream. + /// + /// Can be split into a [`RecvStream`] and [`SendStream`] if the underlying QUIC implementation + /// supports it. + pub struct BidiStream { + #[pin] + stream: BufRecvStream, + } +} + +impl BidiStream { + pub(crate) fn new(stream: BufRecvStream) -> Self { + Self { stream } + } +} + +impl quic::SendStream for BidiStream +where + S: quic::SendStream, + B: Buf, +{ + type Error = S::Error; + + fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.stream.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.stream.send_id() + } + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_ready(cx) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.stream.send_data(data) + } +} + +impl quic::SendStreamUnframed for BidiStream +where + S: quic::SendStreamUnframed, + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) + } +} + +impl quic::RecvStream for BidiStream { + type Buf = Bytes; + + type Error = S::Error; + + fn poll_data( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>> { + self.stream.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stream.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.recv_id() + } +} + +impl quic::BidiStream for BidiStream +where + S: quic::BidiStream, + B: Buf, +{ + type SendStream = SendStream; + + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + let (send, recv) = self.stream.split(); + (SendStream::new(send), RecvStream::new(recv)) + } +} + +impl futures_util::io::AsyncRead for BidiStream +where + BufRecvStream: futures_util::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +impl futures_util::io::AsyncWrite for BidiStream +where + BufRecvStream: futures_util::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_close(cx) + } +} + +impl tokio::io::AsyncRead for BidiStream +where + BufRecvStream: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for BidiStream +where + BufRecvStream: tokio::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + 
} + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_shutdown(cx) + } +} diff --git a/h3/Cargo.toml b/h3/Cargo.toml index f3fef363..7475942a 100644 --- a/h3/Cargo.toml +++ b/h3/Cargo.toml @@ -1,38 +1,48 @@ [package] name = "h3" -version = "0.0.1" +version = "0.0.5" +rust-version = "1.63" authors = [ "Sean McArthur ", "Jean-Christophe BEGUE ", ] license = "MIT" -edition = "2018" +edition = "2021" documentation = "https://docs.rs/h3" repository = "https://github.com/hyperium/h3" readme = "../README.md" description = "An async HTTP/3 implementation." -keywords = ["http3","quic","h3"] -categories = ["network-programming", "web-programming::http-client", "web-programming::http-server"] +keywords = ["http3", "quic", "h3"] +categories = [ + "network-programming", + "web-programming::http-client", + "web-programming::http-server", +] + +[features] +i-implement-a-third-party-backend-and-opt-into-breaking-changes = [] [dependencies] bytes = "1" -futures-util = { version = "0.3", default-features = false } -http = "0.2.3" +futures-util = { version = "0.3", default-features = false, features = ["io"] } +http = "1" tokio = { version = "1", features = ["sync"] } -tracing = "0.1.18" -fastrand = "1.7.0" +pin-project-lite = { version = "0.2", default_features = false } +tracing = "0.1.40" +fastrand = "2.0.1" [dev-dependencies] -assert_matches = "1.3.0" +assert_matches = "1.5.0" futures-util = { version = "0.3", default-features = false, features = ["io"] } proptest = "1" -quinn = { version = "0.8.0", default-features = false, features = [ - "tls-rustls", +quinn = { version = "0.11", default-features = false, features = [ + "runtime-tokio", + "rustls", "ring", ] } -quinn-proto = { version = "0.8.0", default-features = false } -rcgen = "0.9" -rustls = "0.20" +quinn-proto = { version = "0.11", default-features = false } +rcgen = "0.13" +rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std"] } tokio = { version = "1", features = ["rt", "macros", "io-util", "io-std"] } tracing-subscriber = { version = "0.3", default-features = false, features = [ "fmt", @@ -41,3 +51,5 @@ tracing-subscriber = { version = "0.3", default-features = false, features = [ "time", "tracing-log", ] } +futures = { version = "0.3.28" } +tokio-util = { version = "0.7.9" } diff --git a/h3/LICENSE b/h3/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/h3/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/h3/src/buf.rs b/h3/src/buf.rs index 99d12d56..c6c5617e 100644 --- a/h3/src/buf.rs +++ b/h3/src/buf.rs @@ -3,6 +3,7 @@ use std::io::IoSlice; use bytes::{Buf, Bytes}; +#[derive(Debug)] pub(crate) struct BufList { bufs: VecDeque, } @@ -32,11 +33,16 @@ impl BufList { } impl BufList { + pub fn take_first_chunk(&mut self) -> Option { + self.bufs.pop_front() + } + pub fn take_chunk(&mut self, max_len: usize) -> Option { let chunk = self .bufs .front_mut() .map(|chunk| chunk.split_to(usize::min(max_len, chunk.remaining()))); + if let Some(front) = self.bufs.front() { if front.remaining() == 0 { let _ = self.bufs.pop_front(); diff --git a/h3/src/client/builder.rs b/h3/src/client/builder.rs new file mode 100644 index 00000000..d5d875cb --- /dev/null +++ b/h3/src/client/builder.rs @@ -0,0 +1,129 @@ +//! 
HTTP/3 client builder + +use std::{ + marker::PhantomData, + sync::{atomic::AtomicUsize, Arc}, + task::Poll, +}; + +use bytes::{Buf, Bytes}; +use futures_util::future; + +use crate::{ + config::Config, + connection::{ConnectionInner, SharedStateRef}, + error::Error, + quic::{self}, +}; + +use super::connection::{Connection, SendRequest}; + +/// Start building a new HTTP/3 client +pub fn builder() -> Builder { + Builder::new() +} + +/// Create a new HTTP/3 client with default settings +pub async fn new(conn: C) -> Result<(Connection, SendRequest), Error> +where + C: quic::Connection, + O: quic::OpenStreams, +{ + //= https://www.rfc-editor.org/rfc/rfc9114#section-3.3 + //= type=implication + //# Clients SHOULD NOT open more than one HTTP/3 connection to a given IP + //# address and UDP port, where the IP address and port might be derived + //# from a URI, a selected alternative service ([ALTSVC]), a configured + //# proxy, or name resolution of any of these. + Builder::new().build(conn).await +} + +/// HTTP/3 client builder +/// +/// Set the configuration for a new client. +/// +/// # Examples +/// ```rust +/// # use h3::quic; +/// # async fn doc(quic: C) +/// # where +/// # C: quic::Connection, +/// # O: quic::OpenStreams, +/// # B: bytes::Buf, +/// # { +/// let h3_conn = h3::client::builder() +/// .max_field_section_size(8192) +/// .build(quic) +/// .await +/// .expect("Failed to build connection"); +/// # } +/// ``` +pub struct Builder { + config: Config, +} + +impl Builder { + pub(super) fn new() -> Self { + Builder { + config: Default::default(), + } + } + + #[cfg(test)] + pub fn send_settings(&mut self, value: bool) -> &mut Self { + self.config.send_settings = value; + self + } + + /// Set the maximum header size this client is willing to accept + /// + /// See [header size constraints] section of the specification for details. + /// + /// [header size constraints]: https://www.rfc-editor.org/rfc/rfc9114.html#name-header-size-constraints + pub fn max_field_section_size(&mut self, value: u64) -> &mut Self { + self.config.settings.max_field_section_size = value; + self + } + + /// Just like in HTTP/2, HTTP/3 also uses the concept of "grease" + /// to prevent potential interoperability issues in the future. + /// In HTTP/3, the concept of grease is used to ensure that the protocol can evolve + /// and accommodate future changes without breaking existing implementations. 
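
For reference, a minimal usage sketch of this builder (illustrative only: `quic` stands in for any transport implementing `h3::quic::Connection`, mirroring the doc example above, and the `send_grease` knob is the one introduced just below):

```rust
// Illustrative sketch only: `quic` is any transport that implements
// h3::quic::Connection (e.g. a wrapped quinn connection); the builder
// methods used here are the ones added in this patch.
async fn connect<C, O, B>(quic: C) -> Result<(), h3::Error>
where
    C: h3::quic::Connection<B, OpenStreams = O>,
    O: h3::quic::OpenStreams<B>,
    B: bytes::Buf,
{
    let (_driver, _send_request) = h3::client::builder()
        // Refuse response header sections larger than 8 KiB.
        .max_field_section_size(8 * 1024)
        // Opt out of sending reserved (grease) settings and frames.
        .send_grease(false)
        .build(quic)
        .await?;
    Ok(())
}
```
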
+ pub fn send_grease(&mut self, enabled: bool) -> &mut Self { + self.config.send_grease = enabled; + self + } + + /// Create a new HTTP/3 client from a `quic` connection + pub async fn build( + &mut self, + quic: C, + ) -> Result<(Connection, SendRequest), Error> + where + C: quic::Connection, + O: quic::OpenStreams, + B: Buf, + { + let open = quic.opener(); + let conn_state = SharedStateRef::default(); + + let conn_waker = Some(future::poll_fn(|cx| Poll::Ready(cx.waker().clone())).await); + + Ok(( + Connection { + inner: ConnectionInner::new(quic, conn_state.clone(), self.config).await?, + sent_closing: None, + recv_closing: None, + }, + SendRequest { + open, + conn_state, + conn_waker, + max_field_section_size: self.config.settings.max_field_section_size, + sender_count: Arc::new(AtomicUsize::new(1)), + send_grease_frame: self.config.send_grease, + _buf: PhantomData, + }, + )) + } +} diff --git a/h3/src/client.rs b/h3/src/client/connection.rs similarity index 57% rename from h3/src/client.rs rename to h3/src/client/connection.rs index e2643aaa..4fb53cd7 100644 --- a/h3/src/client.rs +++ b/h3/src/client/connection.rs @@ -1,46 +1,27 @@ //! Client implementation of the HTTP/3 protocol use std::{ - convert::TryFrom, marker::PhantomData, sync::{atomic::AtomicUsize, Arc}, task::{Context, Poll, Waker}, }; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Buf, BytesMut}; use futures_util::future; -use http::{request, HeaderMap, Response}; +use http::request; use tracing::{info, trace}; use crate::{ connection::{self, ConnectionInner, ConnectionState, SharedStateRef}, error::{Code, Error, ErrorLevel}, frame::FrameStream, - proto::{frame::Frame, headers::Header, push::PushId, varint::VarInt}, + proto::{frame::Frame, headers::Header, push::PushId}, qpack, quic::{self, StreamId}, - stream, + stream::{self, BufRecvStream}, }; -/// Start building a new HTTP/3 client -pub fn builder() -> Builder { - Builder::new() -} - -/// Create a new HTTP/3 client with default settings -pub async fn new(conn: C) -> Result<(Connection, SendRequest), Error> -where - C: quic::Connection, - O: quic::OpenStreams, -{ - //= https://www.rfc-editor.org/rfc/rfc9114#section-3.3 - //= type=implication - //# Clients SHOULD NOT open more than one HTTP/3 connection to a given IP - //# address and UDP port, where the IP address and port might be derived - //# from a URI, a selected alternative service ([ALTSVC]), a configured - //# proxy, or name resolution of any of these. - Builder::new().build(conn).await -} +use super::stream::RequestStream; /// HTTP/3 request sender /// @@ -124,14 +105,14 @@ where T: quic::OpenStreams, B: Buf, { - open: T, - conn_state: SharedStateRef, - max_field_section_size: u64, // maximum size for a header we receive + pub(super) open: T, + pub(super) conn_state: SharedStateRef, + pub(super) max_field_section_size: u64, // maximum size for a header we receive // counts instances of SendRequest to close the connection when the last is dropped. 
- sender_count: Arc, - conn_waker: Option, - _buf: PhantomData, - send_grease_frame: bool, + pub(super) sender_count: Arc, + pub(super) conn_waker: Option, + pub(super) _buf: PhantomData, + pub(super) send_grease_frame: bool, } impl SendRequest @@ -146,7 +127,7 @@ where ) -> Result, Error> { let (peer_max_field_section_size, closing) = { let state = self.conn_state.read("send request lock state"); - (state.peer_max_field_section_size, state.closing) + (state.peer_config.max_field_section_size, state.closing) }; if closing { @@ -158,9 +139,10 @@ where method, uri, headers, + extensions, .. } = parts; - let headers = Header::request(method, uri, headers)?; + let headers = Header::request(method, uri, headers, extensions)?; //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 //= type=implication @@ -199,7 +181,7 @@ where let request_stream = RequestStream { inner: connection::RequestStream::new( - FrameStream::new(stream), + FrameStream::new(BufRecvStream::new(stream)), self.max_field_section_size, self.conn_state.clone(), self.send_grease_frame, @@ -253,7 +235,7 @@ where .fetch_sub(1, std::sync::atomic::Ordering::AcqRel) == 1 { - if let Some(w) = self.conn_waker.take() { + if let Some(w) = Option::take(&mut self.conn_waker) { w.wake() } self.shared_state().write("SendRequest drop").error = Some(Error::closed()); @@ -304,7 +286,9 @@ where /// ```rust /// # use bytes::Buf; /// # use futures_util::future; -/// # use h3::{client::*, quic}; +/// # use h3::quic; +/// # use h3::client::Connection; +/// # use h3::client::SendRequest; /// # use tokio::{self, sync::oneshot, task::JoinHandle}; /// # async fn doc(mut connection: Connection) /// # -> Result<(), Box> @@ -349,11 +333,11 @@ where C: quic::Connection, B: Buf, { - inner: ConnectionInner, + pub(super) inner: ConnectionInner, // Has a GOAWAY frame been sent? If so, this PushId is the last we are willing to accept. - sent_closing: Option, + pub(super) sent_closing: Option, // Has a GOAWAY frame been received? If so, this is StreamId the last the remote will accept. - recv_closing: Option, + pub(super) recv_closing: Option, } impl Connection @@ -460,294 +444,3 @@ where Poll::Pending } } - -/// HTTP/3 client builder -/// -/// Set the configuration for a new client. -/// -/// # Examples -/// ```rust -/// # use h3::quic; -/// # async fn doc(quic: C) -/// # where -/// # C: quic::Connection, -/// # O: quic::OpenStreams, -/// # B: bytes::Buf, -/// # { -/// let h3_conn = h3::client::builder() -/// .max_field_section_size(8192) -/// .build(quic) -/// .await -/// .expect("Failed to build connection"); -/// # } -/// ``` -pub struct Builder { - max_field_section_size: u64, - send_grease: bool, -} - -impl Builder { - pub(super) fn new() -> Self { - Builder { - max_field_section_size: VarInt::MAX.0, - send_grease: true, - } - } - - /// Set the maximum header size this client is willing to accept - /// - /// See [header size constraints] section of the specification for details. 
- /// - /// [header size constraints]: https://www.rfc-editor.org/rfc/rfc9114.html#name-header-size-constraints - pub fn max_field_section_size(&mut self, value: u64) -> &mut Self { - self.max_field_section_size = value; - self - } - - /// Create a new HTTP/3 client from a `quic` connection - pub async fn build( - &mut self, - quic: C, - ) -> Result<(Connection, SendRequest), Error> - where - C: quic::Connection, - O: quic::OpenStreams, - B: Buf, - { - let open = quic.opener(); - let conn_state = SharedStateRef::default(); - - let conn_waker = Some(future::poll_fn(|cx| Poll::Ready(cx.waker().clone())).await); - - Ok(( - Connection { - inner: ConnectionInner::new( - quic, - self.max_field_section_size, - conn_state.clone(), - self.send_grease, - ) - .await?, - sent_closing: None, - recv_closing: None, - }, - SendRequest { - open, - conn_state, - conn_waker, - max_field_section_size: self.max_field_section_size, - sender_count: Arc::new(AtomicUsize::new(1)), - _buf: PhantomData, - send_grease_frame: self.send_grease, - }, - )) - } -} - -/// Manage request bodies transfer, response and trailers. -/// -/// Once a request has been sent via [`send_request()`], a response can be awaited by calling -/// [`recv_response()`]. A body for this request can be sent with [`send_data()`], then the request -/// shall be completed by either sending trailers with [`send_trailers()`], or [`finish()`]. -/// -/// After receiving the response's headers, it's body can be read by [`recv_data()`] until it returns -/// `None`. Then the trailers will eventually be available via [`recv_trailers()`]. -/// -/// TODO: If data is polled before the response has been received, an error will be thrown. -/// -/// TODO: If trailers are polled but the body hasn't been fully received, an UNEXPECT_FRAME error will be -/// thrown -/// -/// Whenever the client wants to cancel this request, it can call [`stop_sending()`], which will -/// put an end to any transfer concerning it. -/// -/// # Examples -/// -/// ```rust -/// # use h3::{quic, client::*}; -/// # use http::{Request, Response}; -/// # use bytes::Buf; -/// # use tokio::io::AsyncWriteExt; -/// # async fn doc(mut req_stream: RequestStream) -> Result<(), Box> -/// # where -/// # T: quic::RecvStream, -/// # B: Buf, -/// # { -/// // Prepare the HTTP request to send to the server -/// let request = Request::get("https://www.example.com/").body(())?; -/// -/// // Receive the response -/// let response = req_stream.recv_response().await?; -/// // Receive the body -/// while let Some(mut chunk) = req_stream.recv_data().await? { -/// let mut out = tokio::io::stdout(); -/// out.write_all_buf(&mut chunk).await?; -/// out.flush().await?; -/// } -/// # Ok(()) -/// # } -/// # pub fn main() {} -/// ``` -/// -/// [`send_request()`]: struct.SendRequest.html#method.send_request -/// [`recv_response()`]: #method.recv_response -/// [`recv_data()`]: #method.recv_data -/// [`send_data()`]: #method.send_data -/// [`send_trailers()`]: #method.send_trailers -/// [`recv_trailers()`]: #method.recv_trailers -/// [`finish()`]: #method.finish -/// [`stop_sending()`]: #method.stop_sending -pub struct RequestStream { - inner: connection::RequestStream, -} - -impl ConnectionState for RequestStream { - fn shared_state(&self) -> &SharedStateRef { - &self.inner.conn_state - } -} - -impl RequestStream -where - S: quic::RecvStream, -{ - /// Receive the HTTP/3 response - /// - /// This should be called before trying to receive any data with [`recv_data()`]. 
- /// - /// [`recv_data()`]: #method.recv_data - pub async fn recv_response(&mut self) -> Result, Error> { - let mut frame = future::poll_fn(|cx| self.inner.stream.poll_next(cx)) - .await - .map_err(|e| self.maybe_conn_err(e))? - .ok_or_else(|| { - Code::H3_GENERAL_PROTOCOL_ERROR.with_reason( - "Did not receive response headers", - ErrorLevel::ConnectionError, - ) - })?; - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 - //= type=TODO - //# A client MUST treat - //# receipt of a PUSH_PROMISE frame that contains a larger push ID than - //# the client has advertised as a connection error of H3_ID_ERROR. - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 - //= type=TODO - //# If a client - //# receives a push ID that has already been promised and detects a - //# mismatch, it MUST respond with a connection error of type - //# H3_GENERAL_PROTOCOL_ERROR. - - let decoded = if let Frame::Headers(ref mut encoded) = frame { - match qpack::decode_stateless(encoded, self.inner.max_field_section_size) { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 - //# An HTTP/3 implementation MAY impose a limit on the maximum size of - //# the message header it will accept on an individual HTTP message. - Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => { - self.inner.stop_sending(Code::H3_REQUEST_CANCELLED); - return Err(Error::header_too_big( - cancel_size, - self.inner.max_field_section_size, - )); - } - Ok(decoded) => decoded, - Err(e) => return Err(e.into()), - } - } else { - return Err(Code::H3_FRAME_UNEXPECTED.with_reason( - "First response frame is not headers", - ErrorLevel::ConnectionError, - )); - }; - - let qpack::Decoded { fields, .. } = decoded; - - let (status, headers) = Header::try_from(fields)?.into_response_parts()?; - let mut resp = Response::new(()); - *resp.status_mut() = status; - *resp.headers_mut() = headers; - *resp.version_mut() = http::Version::HTTP_3; - - Ok(resp) - } - - /// Receive some of the request body. - // TODO what if called before recv_response ? - pub async fn recv_data(&mut self) -> Result, Error> { - self.inner.recv_data().await - } - - /// Receive an optional set of trailers for the response. - pub async fn recv_trailers(&mut self) -> Result, Error> { - let res = self.inner.recv_trailers().await; - if let Err(ref e) = res { - if e.is_header_too_big() { - self.inner.stream.stop_sending(Code::H3_REQUEST_CANCELLED); - } - } - res - } - - /// Tell the peer to stop sending into the underlying QUIC stream - pub fn stop_sending(&mut self, error_code: crate::error::Code) { - // TODO take by value to prevent any further call as this request is cancelled - // rename `cancel()` ? - self.inner.stream.stop_sending(error_code) - } -} - -impl RequestStream -where - S: quic::SendStream, - B: Buf, -{ - /// Send some data on the request body. - pub async fn send_data(&mut self, buf: B) -> Result<(), Error> { - self.inner.send_data(buf).await - } - - /// Send a set of trailers to end the request. - /// - /// Either [`RequestStream::finish`] or - /// [`RequestStream::send_trailers`] must be called to finalize a - /// request. - pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> { - self.inner.send_trailers(trailers).await - } - - /// End the request without trailers. - /// - /// Either [`RequestStream::finish`] or - /// [`RequestStream::send_trailers`] must be called to finalize a - /// request. 
- pub async fn finish(&mut self) -> Result<(), Error> { - self.inner.finish().await - } - - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.1 - //= type=TODO - //# Implementations SHOULD cancel requests by abruptly terminating any - //# directions of a stream that are still open. To do so, an - //# implementation resets the sending parts of streams and aborts reading - //# on the receiving parts of streams; see Section 2.4 of - //# [QUIC-TRANSPORT]. -} - -impl RequestStream -where - S: quic::BidiStream, - B: Buf, -{ - /// Split this stream into two halves that can be driven independently. - pub fn split( - self, - ) -> ( - RequestStream, - RequestStream, - ) { - let (send, recv) = self.inner.split(); - (RequestStream { inner: send }, RequestStream { inner: recv }) - } -} diff --git a/h3/src/client/mod.rs b/h3/src/client/mod.rs new file mode 100644 index 00000000..2435a4c3 --- /dev/null +++ b/h3/src/client/mod.rs @@ -0,0 +1,12 @@ +//! HTTP/3 client + +mod connection; +mod stream; + +mod builder; + +pub use builder::builder; +pub use builder::new; +pub use builder::Builder; +pub use connection::{Connection, SendRequest}; +pub use stream::RequestStream; diff --git a/h3/src/client/stream.rs b/h3/src/client/stream.rs new file mode 100644 index 00000000..e2b4b8e6 --- /dev/null +++ b/h3/src/client/stream.rs @@ -0,0 +1,219 @@ +use bytes::Buf; +use futures_util::future; +use http::{HeaderMap, Response}; + +use crate::{ + connection::{self, ConnectionState, SharedStateRef}, + error::{Code, Error, ErrorLevel}, + proto::{frame::Frame, headers::Header}, + qpack, + quic::{self}, +}; +use std::convert::TryFrom; + +/// Manage request bodies transfer, response and trailers. +/// +/// Once a request has been sent via [`send_request()`], a response can be awaited by calling +/// [`recv_response()`]. A body for this request can be sent with [`send_data()`], then the request +/// shall be completed by either sending trailers with [`send_trailers()`], or [`finish()`]. +/// +/// After receiving the response's headers, it's body can be read by [`recv_data()`] until it returns +/// `None`. Then the trailers will eventually be available via [`recv_trailers()`]. +/// +/// TODO: If data is polled before the response has been received, an error will be thrown. +/// +/// TODO: If trailers are polled but the body hasn't been fully received, an UNEXPECT_FRAME error will be +/// thrown +/// +/// Whenever the client wants to cancel this request, it can call [`stop_sending()`], which will +/// put an end to any transfer concerning it. +/// +/// # Examples +/// +/// ```rust +/// # use h3::{quic, client::*}; +/// # use http::{Request, Response}; +/// # use bytes::Buf; +/// # use tokio::io::AsyncWriteExt; +/// # async fn doc(mut req_stream: RequestStream) -> Result<(), Box> +/// # where +/// # T: quic::RecvStream, +/// # { +/// // Prepare the HTTP request to send to the server +/// let request = Request::get("https://www.example.com/").body(())?; +/// +/// // Receive the response +/// let response = req_stream.recv_response().await?; +/// // Receive the body +/// while let Some(mut chunk) = req_stream.recv_data().await? 
{ +/// let mut out = tokio::io::stdout(); +/// out.write_all_buf(&mut chunk).await?; +/// out.flush().await?; +/// } +/// # Ok(()) +/// # } +/// # pub fn main() {} +/// ``` +/// +/// [`send_request()`]: struct.SendRequest.html#method.send_request +/// [`recv_response()`]: #method.recv_response +/// [`recv_data()`]: #method.recv_data +/// [`send_data()`]: #method.send_data +/// [`send_trailers()`]: #method.send_trailers +/// [`recv_trailers()`]: #method.recv_trailers +/// [`finish()`]: #method.finish +/// [`stop_sending()`]: #method.stop_sending +pub struct RequestStream { + pub(super) inner: connection::RequestStream, +} + +impl ConnectionState for RequestStream { + fn shared_state(&self) -> &SharedStateRef { + &self.inner.conn_state + } +} + +impl RequestStream +where + S: quic::RecvStream, +{ + /// Receive the HTTP/3 response + /// + /// This should be called before trying to receive any data with [`recv_data()`]. + /// + /// [`recv_data()`]: #method.recv_data + pub async fn recv_response(&mut self) -> Result, Error> { + let mut frame = future::poll_fn(|cx| self.inner.stream.poll_next(cx)) + .await + .map_err(|e| self.maybe_conn_err(e))? + .ok_or_else(|| { + Code::H3_GENERAL_PROTOCOL_ERROR.with_reason( + "Did not receive response headers", + ErrorLevel::ConnectionError, + ) + })?; + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 + //= type=TODO + //# A client MUST treat + //# receipt of a PUSH_PROMISE frame that contains a larger push ID than + //# the client has advertised as a connection error of H3_ID_ERROR. + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 + //= type=TODO + //# If a client + //# receives a push ID that has already been promised and detects a + //# mismatch, it MUST respond with a connection error of type + //# H3_GENERAL_PROTOCOL_ERROR. + + let decoded = if let Frame::Headers(ref mut encoded) = frame { + match qpack::decode_stateless(encoded, self.inner.max_field_section_size) { + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 + //# An HTTP/3 implementation MAY impose a limit on the maximum size of + //# the message header it will accept on an individual HTTP message. + Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => { + self.inner.stop_sending(Code::H3_REQUEST_CANCELLED); + return Err(Error::header_too_big( + cancel_size, + self.inner.max_field_section_size, + )); + } + Ok(decoded) => decoded, + Err(e) => return Err(e.into()), + } + } else { + return Err(Code::H3_FRAME_UNEXPECTED.with_reason( + "First response frame is not headers", + ErrorLevel::ConnectionError, + )); + }; + + let qpack::Decoded { fields, .. } = decoded; + + let (status, headers) = Header::try_from(fields)?.into_response_parts()?; + let mut resp = Response::new(()); + *resp.status_mut() = status; + *resp.headers_mut() = headers; + *resp.version_mut() = http::Version::HTTP_3; + + Ok(resp) + } + + /// Receive some of the request body. + // TODO what if called before recv_response ? + pub async fn recv_data(&mut self) -> Result, Error> { + self.inner.recv_data().await + } + + /// Receive an optional set of trailers for the response. 
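
To complement the receive-side doc example above, here is a hedged sketch of the sending half plus trailer handling, assuming a `req_stream` obtained from `SendRequest::send_request` and `Bytes` request bodies:

```rust
// Sketch only: `req_stream` is the client RequestStream returned by
// SendRequest::send_request, parameterised over Bytes bodies.
use bytes::Bytes;
use h3::{client::RequestStream, quic};

async fn upload<T>(
    mut req_stream: RequestStream<T, Bytes>,
) -> Result<(), Box<dyn std::error::Error>>
where
    T: quic::BidiStream<Bytes>,
{
    // Send the request body, then signal that no trailers follow.
    req_stream.send_data(Bytes::from_static(b"hello")).await?;
    req_stream.finish().await?;

    // Read the response head, drain the body, then pick up optional trailers.
    let _response = req_stream.recv_response().await?;
    while let Some(_chunk) = req_stream.recv_data().await? {}
    if let Some(trailers) = req_stream.recv_trailers().await? {
        println!("trailers: {trailers:?}");
    }
    Ok(())
}
```
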
+    pub async fn recv_trailers(&mut self) -> Result<Option<HeaderMap>, Error> {
+        let res = self.inner.recv_trailers().await;
+        if let Err(ref e) = res {
+            if e.is_header_too_big() {
+                self.inner.stream.stop_sending(Code::H3_REQUEST_CANCELLED);
+            }
+        }
+        res
+    }
+
+    /// Tell the peer to stop sending into the underlying QUIC stream
+    pub fn stop_sending(&mut self, error_code: crate::error::Code) {
+        // TODO take by value to prevent any further call as this request is cancelled
+        // rename `cancel()` ?
+        self.inner.stream.stop_sending(error_code)
+    }
+}
+
+impl<S, B> RequestStream<S, B>
+where
+    S: quic::SendStream<B>,
+    B: Buf,
+{
+    /// Send some data on the request body.
+    pub async fn send_data(&mut self, buf: B) -> Result<(), Error> {
+        self.inner.send_data(buf).await
+    }
+
+    /// Send a set of trailers to end the request.
+    ///
+    /// Either [`RequestStream::finish`] or
+    /// [`RequestStream::send_trailers`] must be called to finalize a
+    /// request.
+    pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
+        self.inner.send_trailers(trailers).await
+    }
+
+    /// End the request without trailers.
+    ///
+    /// Either [`RequestStream::finish`] or
+    /// [`RequestStream::send_trailers`] must be called to finalize a
+    /// request.
+    pub async fn finish(&mut self) -> Result<(), Error> {
+        self.inner.finish().await
+    }
+
+    //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.1
+    //= type=TODO
+    //# Implementations SHOULD cancel requests by abruptly terminating any
+    //# directions of a stream that are still open. To do so, an
+    //# implementation resets the sending parts of streams and aborts reading
+    //# on the receiving parts of streams; see Section 2.4 of
+    //# [QUIC-TRANSPORT].
+}
+
+impl<S, B> RequestStream<S, B>
+where
+    S: quic::BidiStream<B>,
+    B: Buf,
+{
+    /// Split this stream into two halves that can be driven independently.
+    pub fn split(
+        self,
+    ) -> (
+        RequestStream<S::SendStream, B>,
+        RequestStream<S::RecvStream, B>,
+    ) {
+        let (send, recv) = self.inner.split();
+        (RequestStream { inner: send }, RequestStream { inner: recv })
+    }
+}
diff --git a/h3/src/config.rs b/h3/src/config.rs
new file mode 100644
index 00000000..ffff248c
--- /dev/null
+++ b/h3/src/config.rs
@@ -0,0 +1,176 @@
+use std::convert::TryFrom;
+
+use crate::proto::{frame, varint::VarInt};
+
+/// Configures the HTTP/3 connection
+#[derive(Debug, Clone, Copy)]
+#[non_exhaustive]
+pub struct Config {
+    /// Just like in HTTP/2, HTTP/3 also uses the concept of "grease"
+    /// to prevent potential interoperability issues in the future.
+    /// In HTTP/3, the concept of grease is used to ensure that the protocol can evolve
+    /// and accommodate future changes without breaking existing implementations.
+    pub(crate) send_grease: bool,
+
+    #[cfg(test)]
+    pub(crate) send_settings: bool,
+
+    /// HTTP/3 Settings
+    pub settings: Settings,
+}
+
+/// HTTP/3 Settings
+#[derive(Debug, Clone, Copy)]
+pub struct Settings {
+    /// The maximum size of a field section (the uncompressed headers or trailers of a
+    /// single HTTP message) this endpoint is willing to accept, advertised to the peer
+    /// as `SETTINGS_MAX_FIELD_SECTION_SIZE` (RFC 9114, Section 4.2.2).
+    ///
+    /// HTTP/3 compresses field sections with QPACK rather than HPACK; this limit applies
+    /// to the decoded field section, not its compressed representation. The default is
+    /// `VarInt::MAX`, i.e. effectively unlimited.
+ pub(crate) max_field_section_size: u64, + + /// https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.1 + /// Sets `SETTINGS_ENABLE_WEBTRANSPORT` if enabled + pub(crate) enable_webtransport: bool, + /// https://www.rfc-editor.org/info/rfc8441 defines an extended CONNECT method in Section 4, + /// enabled by the SETTINGS_ENABLE_CONNECT_PROTOCOL parameter. + /// That parameter is only defined for HTTP/2. + /// for extended CONNECT in HTTP/3; instead, the SETTINGS_ENABLE_WEBTRANSPORT setting implies that an endpoint supports extended CONNECT. + pub(crate) enable_extended_connect: bool, + /// Enable HTTP Datagrams, see https://datatracker.ietf.org/doc/rfc9297/ for details + pub(crate) enable_datagram: bool, + /// The maximum number of concurrent streams that can be opened by the peer. + pub(crate) max_webtransport_sessions: u64, +} + +impl From<&frame::Settings> for Settings { + fn from(settings: &frame::Settings) -> Self { + let defaults: Self = Default::default(); + Self { + max_field_section_size: settings + .get(frame::SettingId::MAX_HEADER_LIST_SIZE) + .unwrap_or(defaults.max_field_section_size), + enable_webtransport: settings + .get(frame::SettingId::ENABLE_WEBTRANSPORT) + .map(|value| value != 0) + .unwrap_or(defaults.enable_webtransport), + max_webtransport_sessions: settings + .get(frame::SettingId::WEBTRANSPORT_MAX_SESSIONS) + .unwrap_or(defaults.max_webtransport_sessions), + enable_datagram: settings + .get(frame::SettingId::H3_DATAGRAM) + .map(|value| value != 0) + .unwrap_or(defaults.enable_datagram), + enable_extended_connect: settings + .get(frame::SettingId::ENABLE_CONNECT_PROTOCOL) + .map(|value| value != 0) + .unwrap_or(defaults.enable_extended_connect), + } + } +} + +impl TryFrom for frame::Settings { + type Error = frame::SettingsError; + fn try_from(value: Config) -> Result { + let mut settings = frame::Settings::default(); + + let Config { + send_grease, + #[cfg(test)] + send_settings: _, + settings: + Settings { + max_field_section_size, + enable_webtransport, + enable_extended_connect, + enable_datagram, + max_webtransport_sessions, + }, + } = value; + + if send_grease { + // Grease Settings (https://www.rfc-editor.org/rfc/rfc9114.html#name-defined-settings-parameters) + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.4.1 + //# Setting identifiers of the format 0x1f * N + 0x21 for non-negative + //# integer values of N are reserved to exercise the requirement that + //# unknown identifiers be ignored. Such settings have no defined + //# meaning. Endpoints SHOULD include at least one such setting in their + //# SETTINGS frame. + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.4.1 + //# Setting identifiers that were defined in [HTTP/2] where there is no + //# corresponding HTTP/3 setting have also been reserved + //# (Section 11.2.2). These reserved settings MUST NOT be sent, and + //# their receipt MUST be treated as a connection error of type + //# H3_SETTINGS_ERROR. + match settings.insert(frame::SettingId::grease(), 0) { + Ok(_) => (), + Err(err) => tracing::warn!("Error when adding the grease Setting. 
Reason {}", err), + } + } + + settings.insert( + frame::SettingId::MAX_HEADER_LIST_SIZE, + max_field_section_size, + )?; + settings.insert( + frame::SettingId::ENABLE_CONNECT_PROTOCOL, + enable_extended_connect as u64, + )?; + settings.insert( + frame::SettingId::ENABLE_WEBTRANSPORT, + enable_webtransport as u64, + )?; + settings.insert(frame::SettingId::H3_DATAGRAM, enable_datagram as u64)?; + settings.insert( + frame::SettingId::WEBTRANSPORT_MAX_SESSIONS, + max_webtransport_sessions, + )?; + + Ok(settings) + } +} + +impl Default for Settings { + fn default() -> Self { + Self { + max_field_section_size: VarInt::MAX.0, + enable_webtransport: false, + enable_extended_connect: false, + enable_datagram: false, + max_webtransport_sessions: 0, + } + } +} + +impl Settings { + /// https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.1 + /// Sets `SETTINGS_ENABLE_WEBTRANSPORT` if enabled + pub fn enable_webtransport(&self) -> bool { + self.enable_webtransport + } + + /// Enable HTTP Datagrams, see https://datatracker.ietf.org/doc/rfc9297/ for details + pub fn enable_datagram(&self) -> bool { + self.enable_datagram + } + + /// https://www.rfc-editor.org/info/rfc8441 defines an extended CONNECT method in Section 4, + /// enabled by the SETTINGS_ENABLE_CONNECT_PROTOCOL parameter. + /// That parameter is only defined for HTTP/2. + /// for extended CONNECT in HTTP/3; instead, the SETTINGS_ENABLE_WEBTRANSPORT setting implies that an endpoint supports extended CONNECT. + pub fn enable_extended_connect(&self) -> bool { + self.enable_extended_connect + } +} + +impl Default for Config { + fn default() -> Self { + Self { + send_grease: true, + #[cfg(test)] + send_settings: true, + settings: Default::default(), + } + } +} diff --git a/h3/src/connection.rs b/h3/src/connection.rs index abf7cd92..f05ef0bc 100644 --- a/h3/src/connection.rs +++ b/h3/src/connection.rs @@ -7,26 +7,30 @@ use std::{ use bytes::{Buf, Bytes, BytesMut}; use futures_util::{future, ready}; use http::HeaderMap; -use tracing::warn; +use stream::WriteBuf; +use tracing::{trace, warn}; use crate::{ + config::{Config, Settings}, error::{Code, Error}, frame::FrameStream, proto::{ - frame::{Frame, PayloadLen, SettingId, Settings}, + frame::{self, Frame, PayloadLen}, headers::Header, stream::StreamType, varint::VarInt, }, qpack, quic::{self, SendStream as _}, - stream::{self, AcceptRecvStream, AcceptedRecvStream}, + stream::{self, AcceptRecvStream, AcceptedRecvStream, BufRecvStream, UniStreamHeader}, + webtransport::SessionId, }; #[doc(hidden)] +#[non_exhaustive] pub struct SharedState { - // maximum size for a header we send - pub peer_max_field_section_size: u64, + // Peer settings + pub peer_config: Settings, // connection-wide error, concerns all RequestStreams and drivers pub error: Option, // Has a GOAWAY frame been sent or received? 
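
The grease identifiers inserted during the `Config` to `frame::Settings` conversion above follow the RFC 9114 pattern `0x1f * N + 0x21`; a small self-contained check of that formula (illustrative only, not part of the patch):

```rust
/// RFC 9114, Section 7.2.4.1: setting identifiers of the form `0x1f * N + 0x21`
/// (for non-negative N) are reserved for greasing and carry no defined meaning.
fn is_grease_setting_id(id: u64) -> bool {
    id >= 0x21 && (id - 0x21) % 0x1f == 0
}

fn main() {
    assert!(is_grease_setting_id(0x21)); // N = 0
    assert!(is_grease_setting_id(0x1f * 3 + 0x21)); // N = 3
    // SETTINGS_MAX_FIELD_SECTION_SIZE (0x06) is a defined setting, not grease.
    assert!(!is_grease_setting_id(0x06));
}
```
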
@@ -50,13 +54,14 @@ impl SharedStateRef { impl Default for SharedStateRef { fn default() -> Self { Self(Arc::new(RwLock::new(SharedState { - peer_max_field_section_size: VarInt::MAX.0, + peer_config: Default::default(), error: None, closing: false, }))) } } +#[allow(missing_docs)] pub trait ConnectionState { fn shared_state(&self) -> &SharedStateRef; @@ -69,67 +74,84 @@ pub trait ConnectionState { } } +#[allow(missing_docs)] +pub struct AcceptedStreams +where + C: quic::Connection, + B: Buf, +{ + #[allow(missing_docs)] + pub wt_uni_streams: Vec<(SessionId, BufRecvStream)>, +} + +impl Default for AcceptedStreams +where + C: quic::Connection, + B: Buf, +{ + fn default() -> Self { + Self { + wt_uni_streams: Default::default(), + } + } +} + +#[allow(missing_docs)] pub struct ConnectionInner where C: quic::Connection, B: Buf, { pub(super) shared: SharedStateRef, - conn: C, + /// TODO: breaking encapsulation just to see if we can get this to work, will fix before merging + pub conn: C, control_send: C::SendStream, control_recv: Option>, decoder_recv: Option>, encoder_recv: Option>, - pending_recv_streams: Vec>, + /// Buffers incoming uni/recv streams which have yet to be claimed. + /// + /// This is opposed to discarding them by returning in `poll_accept_recv`, which may cause them to be missed by something else polling. + /// + /// See: + /// + /// In WebTransport over HTTP/3, the client MAY send its SETTINGS frame, as well as + /// multiple WebTransport CONNECT requests, WebTransport data streams and WebTransport + /// datagrams, all within a single flight. As those can arrive out of order, a WebTransport + /// server could be put into a situation where it receives a stream or a datagram without a + /// corresponding session. Similarly, a client may receive a server-initiated stream or a + /// datagram before receiving the CONNECT response headers from the server.To handle this + /// case, WebTransport endpoints SHOULD buffer streams and datagrams until those can be + /// associated with an established session. To avoid resource exhaustion, the endpoints + /// MUST limit the number of buffered streams and datagrams. When the number of buffered + /// streams is exceeded, a stream SHALL be closed by sending a RESET_STREAM and/or + /// STOP_SENDING with the H3_WEBTRANSPORT_BUFFERED_STREAM_REJECTED error code. When the + /// number of buffered datagrams is exceeded, a datagram SHALL be dropped. It is up to an + /// implementation to choose what stream or datagram to discard. + accepted_streams: AcceptedStreams, + + pending_recv_streams: Vec>, + got_peer_settings: bool, - pub(super) send_grease_frame: bool, + pub send_grease_frame: bool, + pub config: Config, } -impl ConnectionInner +impl ConnectionInner where C: quic::Connection, B: Buf, { - pub async fn new( - mut conn: C, - max_field_section_size: u64, - shared: SharedStateRef, - grease: bool, - ) -> Result { - //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2 - //# Endpoints SHOULD create the HTTP control stream as well as the - //# unidirectional streams required by mandatory extensions (such as the - //# QPACK encoder and decoder streams) first, and then create additional - //# streams as allowed by their peer. 
- let mut control_send = future::poll_fn(|cx| conn.poll_open_send(cx)) - .await - .map_err(|e| Code::H3_STREAM_CREATION_ERROR.with_transport(e))?; + pub async fn send_settings(&mut self) -> Result<(), Error> { + #[cfg(test)] + if !self.config.send_settings { + return Ok(()); + } - let mut settings = Settings::default(); - settings - .insert(SettingId::MAX_HEADER_LIST_SIZE, max_field_section_size) + let settings = frame::Settings::try_from(self.config) .map_err(|e| Code::H3_INTERNAL_ERROR.with_cause(e))?; - if grease { - // Grease Settings (https://www.rfc-editor.org/rfc/rfc9114.html#name-defined-settings-parameters) - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.4.1 - //# Setting identifiers of the format 0x1f * N + 0x21 for non-negative - //# integer values of N are reserved to exercise the requirement that - //# unknown identifiers be ignored. Such settings have no defined - //# meaning. Endpoints SHOULD include at least one such setting in their - //# SETTINGS frame. - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.4.1 - //# Setting identifiers that were defined in [HTTP/2] where there is no - //# corresponding HTTP/3 setting have also been reserved - //# (Section 11.2.2). These reserved settings MUST NOT be sent, and - //# their receipt MUST be treated as a connection error of type - //# H3_SETTINGS_ERROR. - match settings.insert(SettingId::grease(), 0) { - Ok(_) => (), - Err(err) => warn!("Error when adding the grease Setting. Reason {}", err), - } - } + tracing::debug!("Sending server settings: {:#x?}", settings); //= https://www.rfc-editor.org/rfc/rfc9114#section-3.2 //# After the QUIC connection is @@ -156,12 +178,27 @@ where //# Endpoints MUST NOT require any data to be received from //# the peer prior to sending the SETTINGS frame; settings MUST be sent //# as soon as the transport is ready to send data. + trace!("Sending Settings frame: {:#x?}", settings); stream::write( - &mut control_send, - (StreamType::CONTROL, Frame::Settings(settings)), + &mut self.control_send, + WriteBuf::from(UniStreamHeader::Control(settings)), ) .await?; + Ok(()) + } + + /// Initiates the connection and opens a control stream + pub async fn new(mut conn: C, shared: SharedStateRef, config: Config) -> Result { + //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2 + //# Endpoints SHOULD create the HTTP control stream as well as the + //# unidirectional streams required by mandatory extensions (such as the + //# QPACK encoder and decoder streams) first, and then create additional + //# streams as allowed by their peer. + let control_send = future::poll_fn(|cx| conn.poll_open_send(cx)) + .await + .map_err(|e| Code::H3_STREAM_CREATION_ERROR.with_transport(e))?; + //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 //= type=implication //# The @@ -176,10 +213,15 @@ where encoder_recv: None, pending_recv_streams: Vec::with_capacity(3), got_peer_settings: false, - send_grease_frame: grease, + send_grease_frame: config.send_grease, + config, + accepted_streams: Default::default(), }; + + conn_inner.send_settings().await?; + // start a grease stream - if grease { + if config.send_grease { //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.8 //= type=implication //# Frame types of the format 0x1f * N + 0x21 for non-negative integer @@ -193,6 +235,7 @@ where } /// Send GOAWAY with specified max_id, iff max_id is smaller than the previous one. 
+ pub async fn shutdown( &mut self, sent_closing: &mut Option, @@ -220,6 +263,7 @@ where stream::write(&mut self.control_send, Frame::Goaway(max_id.into())).await } + #[allow(missing_docs)] pub fn poll_accept_request( &mut self, cx: &mut Context<'_>, @@ -231,26 +275,31 @@ where } } + // Accept the request by accepting the next bidirectional stream // .into().into() converts the impl QuicError into crate::error::Error. // The `?` operator doesn't work here for some reason. self.conn.poll_accept_bidi(cx).map_err(|e| e.into().into()) } - pub fn poll_accept_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + /// Polls incoming streams + /// + /// Accepted streams which are not control, decoder, or encoder streams are buffer in `accepted_recv_streams` + pub fn poll_accept_recv(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { if let Some(ref e) = self.shared.read("poll_accept_request").error { - return Poll::Ready(Err(e.clone())); + return Err(e.clone()); } + // Get all currently pending streams loop { match self.conn.poll_accept_recv(cx)? { Poll::Ready(Some(stream)) => self .pending_recv_streams .push(AcceptRecvStream::new(stream)), Poll::Ready(None) => { - return Poll::Ready(Err(Code::H3_GENERAL_PROTOCOL_ERROR.with_reason( + return Err(Code::H3_GENERAL_PROTOCOL_ERROR.with_reason( "Connection closed unexpected", crate::error::ErrorLevel::ConnectionError, - ))) + )) } Poll::Pending => break, } @@ -275,6 +324,7 @@ where .pending_recv_streams .remove(index - removed) .into_stream()?; + match stream { //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 //# Only one control stream per peer is permitted; @@ -282,26 +332,33 @@ where //# treated as a connection error of type H3_STREAM_CREATION_ERROR. AcceptedRecvStream::Control(s) => { if self.control_recv.is_some() { - return Poll::Ready(Err( + return Err( self.close(Code::H3_STREAM_CREATION_ERROR, "got two control streams") - )); + ); } self.control_recv = Some(s); } enc @ AcceptedRecvStream::Encoder(_) => { if let Some(_prev) = self.encoder_recv.replace(enc) { - return Poll::Ready(Err( + return Err( self.close(Code::H3_STREAM_CREATION_ERROR, "got two encoder streams") - )); + ); } } dec @ AcceptedRecvStream::Decoder(_) => { if let Some(_prev) = self.decoder_recv.replace(dec) { - return Poll::Ready(Err( + return Err( self.close(Code::H3_STREAM_CREATION_ERROR, "got two decoder streams") - )); + ); } } + AcceptedRecvStream::WebTransportUni(id, s) + if self.config.settings.enable_webtransport => + { + // Store until someone else picks it up, like a webtransport session which is + // not yet established. + self.accepted_streams.wt_uni_streams.push((id, s)) + } //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.3 //= type=implication @@ -311,28 +368,27 @@ where } } - Poll::Pending + Ok(()) } + /// Waits for the control stream to be received and reads subsequent frames. 
pub fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll, Error>> { if let Some(ref e) = self.shared.read("poll_accept_request").error { return Poll::Ready(Err(e.clone())); } - loop { - match self.poll_accept_recv(cx) { - Poll::Ready(Ok(_)) => continue, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending if self.control_recv.is_none() => return Poll::Pending, - _ => break, + let recv = { + // TODO + self.poll_accept_recv(cx)?; + if let Some(v) = &mut self.control_recv { + v + } else { + // Try later + return Poll::Pending; } - } + }; - let recvd = ready!(self - .control_recv - .as_mut() - .expect("control_recv") - .poll_next(cx))?; + let recvd = ready!(recv.poll_next(cx))?; let res = match recvd { //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 @@ -367,11 +423,9 @@ where //= type=implication //# Endpoints MUST NOT consider such settings to have //# any meaning upon receipt. - self.shared - .write("connection settings write") - .peer_max_field_section_size = settings - .get(SettingId::MAX_HEADER_LIST_SIZE) - .unwrap_or(VarInt::MAX.0); + let mut shared = self.shared.write("connection settings write"); + shared.peer_config = (&settings).into(); + Ok(Frame::Settings(settings)) } f @ Frame::Goaway(_) => Ok(f), @@ -524,8 +578,14 @@ where Err(err) => warn!("grease stream error on close {}", err), }; } + + #[allow(missing_docs)] + pub fn accepted_streams_mut(&mut self) -> &mut AcceptedStreams { + &mut self.accepted_streams + } } +#[allow(missing_docs)] pub struct RequestStream { pub(super) stream: FrameStream, pub(super) trailers: Option, @@ -535,6 +595,7 @@ pub struct RequestStream { } impl RequestStream { + #[allow(missing_docs)] pub fn new( stream: FrameStream, max_field_section_size: u64, @@ -562,16 +623,21 @@ where S: quic::RecvStream, { /// Receive some of the request body. - pub async fn recv_data(&mut self) -> Result, Error> { + pub fn poll_recv_data( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Error>> { if !self.stream.has_data() { - let frame = future::poll_fn(|cx| self.stream.poll_next(cx)) - .await + let frame = self + .stream + .poll_next(cx) .map_err(|e| self.maybe_conn_err(e))?; - match frame { + + match ready!(frame) { Some(Frame::Data { .. }) => (), Some(Frame::Headers(encoded)) => { self.trailers = Some(encoded); - return Ok(None); + return Poll::Ready(Ok(None)); } //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 @@ -596,15 +662,18 @@ where //# The MAX_PUSH_ID frame is always sent on the control stream. Receipt //# of a MAX_PUSH_ID frame on any other stream MUST be treated as a //# connection error of type H3_FRAME_UNEXPECTED. - Some(_) => return Err(Code::H3_FRAME_UNEXPECTED.into()), - None => return Ok(None), + Some(_) => return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.into())), + None => return Poll::Ready(Ok(None)), } } - let data = future::poll_fn(|cx| self.stream.poll_data(cx)) - .await - .map_err(|e| self.maybe_conn_err(e))?; - Ok(data) + self.stream + .poll_data(cx) + .map_err(|e| self.maybe_conn_err(e)) + } + /// Receive some of the request body. 
+ pub async fn recv_data(&mut self) -> Result, Error> { + future::poll_fn(|cx| self.poll_recv_data(cx)).await } /// Receive trailers @@ -675,6 +744,7 @@ where Ok(Some(Header::try_from(fields)?.into_fields())) } + #[allow(missing_docs)] pub fn stop_sending(&mut self, err_code: Code) { self.stream.stop_sending(err_code); } @@ -707,7 +777,8 @@ where let max_mem_size = self .conn_state .read("send_trailers shared state read") - .peer_max_field_section_size; + .peer_config + .max_field_section_size; //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 //# An implementation that @@ -729,6 +800,7 @@ where self.stream.reset(code.into()); } + #[allow(missing_docs)] pub async fn finish(&mut self) -> Result<(), Error> { if self.send_grease_frame { // send a grease frame once per Connection @@ -737,9 +809,7 @@ where .map_err(|e| self.maybe_conn_err(e))?; self.send_grease_frame = false; } - future::poll_fn(|cx| self.stream.poll_ready(cx)) - .await - .map_err(|e| self.maybe_conn_err(e))?; + future::poll_fn(|cx| self.stream.poll_finish(cx)) .await .map_err(|e| self.maybe_conn_err(e)) diff --git a/h3/src/error.rs b/h3/src/error.rs index 5c693cd6..b622262c 100644 --- a/h3/src/error.rs +++ b/h3/src/error.rs @@ -12,6 +12,7 @@ pub(crate) type TransportError = Box; /// A general error that can occur when handling the HTTP/3 protocol. #[derive(Clone)] pub struct Error { + /// The error kind. pub(crate) inner: Box, } @@ -37,6 +38,7 @@ impl PartialEq for Code { } } +/// The error kind. #[derive(Clone)] pub(crate) struct ErrorImpl { pub(crate) kind: Kind, @@ -110,6 +112,9 @@ macro_rules! codes { } codes! { + /// Datagram or capsule parse error + /// See: + (0x33, H3_DATAGRAM_ERROR); /// No error. This is used when the connection or stream needs to be /// closed, but there is no error to signal. (0x100, H3_NO_ERROR); @@ -272,7 +277,6 @@ impl Error { matches!(&self.inner.kind, Kind::HeaderTooBig { .. }) } - #[cfg(test)] #[doc(hidden)] pub fn kind(&self) -> Kind { self.inner.kind.clone() diff --git a/h3/src/ext.rs b/h3/src/ext.rs new file mode 100644 index 00000000..ff1f0e58 --- /dev/null +++ b/h3/src/ext.rs @@ -0,0 +1,122 @@ +//! Extensions for the HTTP/3 protocol. 
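
As an illustration of the `Protocol` pseudo-header type defined just below in this new module, a hedged sketch of parsing and comparing `:protocol` values:

```rust
use std::str::FromStr;

use h3::ext::Protocol;

fn main() {
    // The extended-CONNECT ":protocol" value used by WebTransport.
    let proto = Protocol::from_str("webtransport").ok().expect("known protocol");
    assert_eq!(proto, Protocol::WEB_TRANSPORT);
    assert_eq!(proto.as_str(), "webtransport");

    // Unknown values are rejected rather than silently accepted.
    assert!(Protocol::from_str("smtp").is_err());
}
```
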
+ +use std::convert::TryFrom; +use std::str::FromStr; + +use bytes::{Buf, Bytes}; + +use crate::{ + error::Code, + proto::{stream::StreamId, varint::VarInt}, + Error, +}; + +/// Describes the `:protocol` pseudo-header for extended connect +/// +/// See: +#[derive(Copy, PartialEq, Debug, Clone)] +pub struct Protocol(ProtocolInner); + +impl Protocol { + /// WebTransport protocol + pub const WEB_TRANSPORT: Protocol = Protocol(ProtocolInner::WebTransport); + /// RFC 9298 protocol + pub const CONNECT_UDP: Protocol = Protocol(ProtocolInner::ConnectUdp); + + /// Return a &str representation of the `:protocol` pseudo-header value + #[inline] + pub fn as_str(&self) -> &str { + match self.0 { + ProtocolInner::WebTransport => "webtransport", + ProtocolInner::ConnectUdp => "connect-udp", + } + } +} + +#[derive(Copy, PartialEq, Debug, Clone)] +enum ProtocolInner { + WebTransport, + ConnectUdp, +} + +/// Error when parsing the protocol +pub struct InvalidProtocol; + +impl FromStr for Protocol { + type Err = InvalidProtocol; + + fn from_str(s: &str) -> Result { + match s { + "webtransport" => Ok(Self(ProtocolInner::WebTransport)), + "connect-udp" => Ok(Self(ProtocolInner::ConnectUdp)), + _ => Err(InvalidProtocol), + } + } +} + +/// HTTP datagram frames +/// See: +pub struct Datagram { + /// Stream id divided by 4 + stream_id: StreamId, + /// The data contained in the datagram + payload: B, +} + +impl Datagram +where + B: Buf, +{ + /// Creates a new datagram frame + pub fn new(stream_id: StreamId, payload: B) -> Self { + assert!( + stream_id.into_inner() % 4 == 0, + "StreamId is not divisible by 4" + ); + Self { stream_id, payload } + } + + /// Decodes a datagram frame from the QUIC datagram + pub fn decode(mut buf: B) -> Result { + let q_stream_id = VarInt::decode(&mut buf) + .map_err(|_| Code::H3_DATAGRAM_ERROR.with_cause("Malformed datagram frame"))?; + + //= https://www.rfc-editor.org/rfc/rfc9297#section-2.1 + // Quarter Stream ID: A variable-length integer that contains the value of the client-initiated bidirectional + // stream that this datagram is associated with divided by four (the division by four stems + // from the fact that HTTP requests are sent on client-initiated bidirectional streams, + // which have stream IDs that are divisible by four). The largest legal QUIC stream ID + // value is 262-1, so the largest legal value of the Quarter Stream ID field is 260-1. + // Receipt of an HTTP/3 Datagram that includes a larger value MUST be treated as an HTTP/3 + // connection error of type H3_DATAGRAM_ERROR (0x33). 
+ let stream_id = StreamId::try_from(u64::from(q_stream_id) * 4) + .map_err(|_| Code::H3_DATAGRAM_ERROR.with_cause("Invalid stream id"))?; + + let payload = buf; + + Ok(Self { stream_id, payload }) + } + + #[inline] + /// Returns the associated stream id of the datagram + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + #[inline] + /// Returns the datagram payload + pub fn payload(&self) -> &B { + &self.payload + } + + /// Encode the datagram to wire format + pub fn encode(self, buf: &mut D) { + (VarInt::from(self.stream_id) / 4).encode(buf); + buf.put(self.payload); + } + + /// Returns the datagram payload + pub fn into_payload(self) -> B { + self.payload + } +} diff --git a/h3/src/frame.rs b/h3/src/frame.rs index 3c50e961..a166fa1b 100644 --- a/h3/src/frame.rs +++ b/h3/src/frame.rs @@ -1,11 +1,10 @@ -use std::marker::PhantomData; use std::task::{Context, Poll}; -use bytes::{Buf, Bytes}; +use bytes::Buf; -use futures_util::ready; use tracing::trace; +use crate::stream::{BufRecvStream, WriteBuf}; use crate::{ buf::BufList, error::TransportError, @@ -14,34 +13,30 @@ use crate::{ stream::StreamId, }, quic::{BidiStream, RecvStream, SendStream}, - stream::WriteBuf, }; +/// Decodes Frames from the underlying QUIC stream pub struct FrameStream { - stream: S, - bufs: BufList, + pub stream: BufRecvStream, + // Already read data from the stream decoder: FrameDecoder, remaining_data: usize, - /// Set to true when `stream` reaches the end. - is_eos: bool, - _phantom_buffer: PhantomData, } impl FrameStream { - pub fn new(stream: S) -> Self { - Self::with_bufs(stream, BufList::new()) - } - - pub(crate) fn with_bufs(stream: S, bufs: BufList) -> Self { + pub fn new(stream: BufRecvStream) -> Self { Self { stream, - bufs, decoder: FrameDecoder::default(), remaining_data: 0, - is_eos: false, - _phantom_buffer: PhantomData, } } + + /// Unwraps the Framed streamer and returns the underlying stream **without** data loss for + /// partially received/read frames. + pub fn into_inner(self) -> BufRecvStream { + self.stream + } } impl FrameStream @@ -60,18 +55,22 @@ where loop { let end = self.try_recv(cx)?; - return match self.decoder.decode(&mut self.bufs)? { + return match self.decoder.decode(self.stream.buf_mut())? { Some(Frame::Data(PayloadLen(len))) => { self.remaining_data = len; Poll::Ready(Ok(Some(Frame::Data(PayloadLen(len))))) } + frame @ Some(Frame::WebTransportStream(_)) => { + self.remaining_data = usize::MAX; + Poll::Ready(Ok(frame)) + } Some(frame) => Poll::Ready(Ok(Some(frame))), None => match end { // Received a chunk but frame is incomplete, poll until we get `Pending`. Poll::Ready(false) => continue, Poll::Pending => Poll::Pending, Poll::Ready(true) => { - if self.bufs.has_remaining() { + if self.stream.buf_mut().has_remaining() { // Reached the end of receive stream, but there is still some data: // The frame is incomplete. Poll::Ready(Err(FrameStreamError::UnexpectedEnd)) @@ -84,6 +83,10 @@ where } } + /// Retrieves the next piece of data in an incoming data packet or webtransport stream + /// + /// + /// WebTransport bidirectional payload has no finite length and is processed until the end of the stream. 
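
A short sketch of the RFC 9297 wire format handled by the `Datagram` type above (assumes a `Bytes` payload; the byte literal encodes Quarter Stream ID 1, i.e. request stream 4):

```rust
use bytes::Bytes;

fn main() {
    // RFC 9297 wire format: a varint "Quarter Stream ID" followed by the payload.
    // Quarter Stream ID 1 corresponds to client-initiated request stream 4.
    let wire = Bytes::from_static(&[0x01, b'h', b'i']);

    let datagram = h3::ext::Datagram::decode(wire).expect("well-formed datagram");
    println!("associated request stream: {:?}", datagram.stream_id());
    assert_eq!(datagram.into_payload(), Bytes::from_static(b"hi"));
}
```
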
pub fn poll_data( &mut self, cx: &mut Context<'_>, @@ -92,14 +95,19 @@ where return Poll::Ready(Ok(None)); }; - let end = ready!(self.try_recv(cx))?; - let data = self.bufs.take_chunk(self.remaining_data as usize); + let end = match self.try_recv(cx) { + Poll::Ready(Ok(end)) => end, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => false, + }; + let data = self.stream.buf_mut().take_chunk(self.remaining_data); match (data, end) { (None, true) => Poll::Ready(Ok(None)), (None, false) => Poll::Pending, (Some(d), true) - if d.remaining() < self.remaining_data && !self.bufs.has_remaining() => + if d.remaining() < self.remaining_data + && !self.stream.buf_mut().has_remaining() => { Poll::Ready(Err(FrameStreamError::UnexpectedEnd)) } @@ -110,6 +118,7 @@ where } } + /// Stops the underlying stream with the provided error code pub(crate) fn stop_sending(&mut self, error_code: crate::error::Code) { self.stream.stop_sending(error_code.into()); } @@ -119,26 +128,23 @@ where } pub(crate) fn is_eos(&self) -> bool { - self.is_eos && !self.bufs.has_remaining() + self.stream.is_eos() && !self.stream.buf().has_remaining() } fn try_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.is_eos { + if self.stream.is_eos() { return Poll::Ready(Ok(true)); } - match self.stream.poll_data(cx) { + match self.stream.poll_read(cx) { Poll::Ready(Err(e)) => Poll::Ready(Err(FrameStreamError::Quic(e.into()))), Poll::Pending => Poll::Pending, - Poll::Ready(Ok(None)) => { - self.is_eos = true; - Poll::Ready(Ok(true)) - } - Poll::Ready(Ok(Some(mut d))) => { - self.bufs.push_bytes(&mut d); - Poll::Ready(Ok(false)) - } + Poll::Ready(Ok(eos)) => Poll::Ready(Ok(eos)), } } + + pub fn id(&self) -> StreamId { + self.stream.recv_id() + } } impl SendStream for FrameStream @@ -164,8 +170,8 @@ where self.stream.reset(reset_code) } - fn id(&self) -> StreamId { - self.stream.id() + fn send_id(&self) -> StreamId { + self.stream.send_id() } } @@ -179,19 +185,13 @@ where ( FrameStream { stream: send, - bufs: BufList::new(), decoder: FrameDecoder::default(), remaining_data: 0, - is_eos: false, - _phantom_buffer: PhantomData, }, FrameStream { stream: recv, - bufs: self.bufs, decoder: self.decoder, remaining_data: self.remaining_data, - is_eos: self.is_eos, - _phantom_buffer: PhantomData, }, ) } @@ -223,7 +223,7 @@ impl FrameDecoder { let (pos, decoded) = { let mut cur = src.cursor(); let decoded = Frame::decode(&mut cur); - (cur.position() as usize, decoded) + (cur.position(), decoded) }; match decoded { @@ -266,7 +266,7 @@ mod tests { use super::*; use assert_matches::assert_matches; - use bytes::{BufMut, BytesMut}; + use bytes::{BufMut, Bytes, BytesMut}; use futures_util::future::poll_fn; use std::{collections::VecDeque, fmt, sync::Arc}; @@ -373,7 +373,7 @@ mod tests { Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!(|cx| stream.poll_next(cx), Ok(Some(Frame::Headers(_)))); assert_poll_matches!( @@ -395,7 +395,7 @@ mod tests { Frame::headers(&b"header"[..]).encode_with_payload(&mut buf); let mut buf = buf.freeze(); recv.chunk(buf.split_to(buf.len() - 1)); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -414,7 +414,7 @@ mod tests { FrameType::DATA.encode(&mut buf); 
VarInt::from(4u32).encode(&mut buf); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -436,7 +436,7 @@ mod tests { let mut buf = buf.freeze(); recv.chunk(buf.split_to(buf.len() - 2)); recv.chunk(buf); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); // We get the total size of data about to be received assert_poll_matches!( @@ -465,7 +465,7 @@ mod tests { VarInt::from(4u32).encode(&mut buf); buf.put_slice(&b"b"[..]); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -497,7 +497,7 @@ mod tests { Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -519,7 +519,7 @@ mod tests { buf.put_slice(&b"bo"[..]); recv.chunk(buf.clone().freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -528,7 +528,7 @@ mod tests { buf.truncate(0); buf.put_slice(&b"dy"[..]); - stream.bufs.push_bytes(&mut buf.freeze()); + stream.stream.buf_mut().push_bytes(&mut buf.freeze()); assert_poll_matches!( |cx| to_bytes(stream.poll_data(cx)), @@ -569,6 +569,10 @@ mod tests { fn stop_sending(&mut self, _: u64) { unimplemented!() } + + fn recv_id(&self) -> StreamId { + unimplemented!() + } } #[derive(Debug)] diff --git a/h3/src/lib.rs b/h3/src/lib.rs index 1cb59cd8..7dc10006 100644 --- a/h3/src/lib.rs +++ b/h3/src/lib.rs @@ -3,22 +3,48 @@ #![allow(clippy::derive_partial_eq_without_eq)] pub mod client; + +mod config; pub mod error; +pub mod ext; pub mod quic; + pub mod server; pub use error::Error; mod buf; + +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod connection; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod frame; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod proto; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod stream; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod webtransport; + +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod connection; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod frame; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod proto; -#[allow(dead_code)] -mod qpack; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod stream; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] +mod webtransport; +#[allow(dead_code)] +mod qpack; #[cfg(test)] mod tests; - #[cfg(test)] extern crate self as h3; diff --git a/h3/src/proto/frame.rs 
b/h3/src/proto/frame.rs index 340a9598..b65886e1 100644 --- a/h3/src/proto/frame.rs +++ b/h3/src/proto/frame.rs @@ -1,9 +1,14 @@ use bytes::{Buf, BufMut, Bytes}; -use std::{convert::TryInto, fmt}; +use std::{ + convert::TryInto, + fmt::{self, Debug}, +}; use tracing::trace; +use crate::webtransport::SessionId; + use super::{ - coding::Encode, + coding::{Decode, Encode}, push::{InvalidPushId, PushId}, stream::InvalidStreamId, varint::{BufExt, BufMutExt, UnexpectedEnd, VarInt}, @@ -46,13 +51,21 @@ pub enum Frame { PushPromise(PushPromise), Goaway(VarInt), MaxPushId(PushId), + /// Describes the header for a webtransport stream. + /// + /// The payload is sent streaming until the stream is closed + /// + /// Unwrap the framed streamer and read the inner stream until the end. + /// + /// Conversely, when sending, send this frame and unwrap the stream + WebTransportStream(SessionId), Grease, } /// Represents the available data len for a `Data` frame on a RecvStream /// /// Decoding received frames does not handle `Data` frames payload. Instead, receiving it -/// and passing it to the user is left under the responsability of `RequestStream`s. +/// and passing it to the user is left under the responsibility of `RequestStream`s. pub struct PayloadLen(pub usize); impl From for PayloadLen { @@ -62,11 +75,21 @@ impl From for PayloadLen { } impl Frame { - pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE * 3; + pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE * 7; + /// Decodes a Frame from the stream according to pub fn decode(buf: &mut T) -> Result { let remaining = buf.remaining(); let ty = FrameType::decode(buf).map_err(|_| FrameError::Incomplete(remaining + 1))?; + + // Webtransport streams need special handling as they have no length. + // + // See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-4.2 + if ty == FrameType::WEBTRANSPORT_BI_STREAM { + tracing::trace!("webtransport frame"); + return Ok(Frame::WebTransportStream(SessionId::decode(buf)?)); + } + let len = buf .get_var() .map_err(|_| FrameError::Incomplete(remaining + 1))?; @@ -80,6 +103,7 @@ impl Frame { } let mut payload = buf.take(len as usize); + trace!("frame ty: {:?}", ty); let frame = match ty { FrameType::HEADERS => Ok(Frame::Headers(payload.copy_to_bytes(len as usize))), FrameType::SETTINGS => Ok(Frame::Settings(Settings::decode(&mut payload)?)), @@ -91,11 +115,13 @@ impl Frame { | FrameType::H2_PING | FrameType::H2_WINDOW_UPDATE | FrameType::H2_CONTINUATION => Err(FrameError::UnsupportedFrame(ty.0)), + FrameType::WEBTRANSPORT_BI_STREAM | FrameType::DATA => unreachable!(), _ => { buf.advance(len as usize); Err(FrameError::UnknownFrame(ty.0)) } }; + if let Ok(frame) = &frame { trace!( "got frame {:?}, len: {}, remaining: {}", @@ -132,6 +158,11 @@ where buf.write_var(6); buf.put_slice(b"grease"); } + Frame::WebTransportStream(id) => { + FrameType::WEBTRANSPORT_BI_STREAM.encode(buf); + id.encode(buf); + // rest of the data is sent streaming + } } } } @@ -189,6 +220,7 @@ impl fmt::Debug for Frame { Frame::Goaway(id) => write!(f, "GoAway({})", id), Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id), Frame::Grease => write!(f, "Grease()"), + Frame::WebTransportStream(session) => write!(f, "WebTransportStream({:?})", session), } } } @@ -207,6 +239,7 @@ where Frame::Goaway(id) => write!(f, "GoAway({})", id), Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id), Frame::Grease => write!(f, "Grease()"), + Frame::WebTransportStream(_) => write!(f, "WebTransportStream()"), } } } @@ -226,6 +259,9 @@ impl 
PartialEq> for Frame { Frame::Goaway(x) => matches!(other, Frame::Goaway(y) if x == y), Frame::MaxPushId(x) => matches!(other, Frame::MaxPushId(y) if x == y), Frame::Grease => matches!(other, Frame::Grease), + Frame::WebTransportStream(x) => { + matches!(other, Frame::WebTransportStream(y) if x == y) + } } } } @@ -257,6 +293,8 @@ frame_types! { H2_WINDOW_UPDATE = 0x8, H2_CONTINUATION = 0x9, MAX_PUSH_ID = 0xD, + // Reserved frame types + WEBTRANSPORT_BI_STREAM = 0x41, } impl FrameType { @@ -350,7 +388,11 @@ impl SettingId { self, SettingId::MAX_HEADER_LIST_SIZE | SettingId::QPACK_MAX_TABLE_CAPACITY - | SettingId::QPACK_MAX_BLOCKED_STREAMS, + | SettingId::QPACK_MAX_BLOCKED_STREAMS + | SettingId::ENABLE_CONNECT_PROTOCOL + | SettingId::ENABLE_WEBTRANSPORT + | SettingId::WEBTRANSPORT_MAX_SESSIONS + | SettingId::H3_DATAGRAM, ) } @@ -389,9 +431,19 @@ setting_identifiers! { QPACK_MAX_TABLE_CAPACITY = 0x1, QPACK_MAX_BLOCKED_STREAMS = 0x7, MAX_HEADER_LIST_SIZE = 0x6, + // https://datatracker.ietf.org/doc/html/rfc9220#section-5 + ENABLE_CONNECT_PROTOCOL = 0x8, + // https://datatracker.ietf.org/doc/html/rfc9297#name-http-3-setting + H3_DATAGRAM = 0x33, + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-8.2 + ENABLE_WEBTRANSPORT = 0x2B603742, + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-8.2 + H3_SETTING_ENABLE_DATAGRAM_CHROME_SPECIFIC= 0xFFD277, + + WEBTRANSPORT_MAX_SESSIONS = 0x2b603743, } -const SETTINGS_LEN: usize = 4; +const SETTINGS_LEN: usize = 8; #[derive(Debug, PartialEq)] pub struct Settings { @@ -446,7 +498,7 @@ impl Settings { None } - pub(super) fn encode(&self, buf: &mut T) { + pub(crate) fn encode(&self, buf: &mut T) { self.encode_header(buf); for (id, val) in self.entries[..self.len].iter() { id.encode(buf); @@ -483,6 +535,8 @@ impl Settings { //# their receipt MUST be treated as a connection error of type //# H3_SETTINGS_ERROR. 
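The identifiers registered above (ENABLE_CONNECT_PROTOCOL = 0x8, H3_DATAGRAM = 0x33, ENABLE_WEBTRANSPORT = 0x2B603742, WEBTRANSPORT_MAX_SESSIONS = 0x2B603743) ride in the same SETTINGS payload shape as the pre-existing ones: a flat sequence of (identifier, value) pairs, each a QUIC variable-length integer, hence the fixed SETTINGS_LEN capacity growing from 4 to 8 entries. A self-contained sketch of that wire shape, with illustrative helper names rather than the crate's internal encoder:

fn encode_varint(v: u64, out: &mut Vec<u8>) {
    // RFC 9000 §16 variable-length integer: two prefix bits select a 1/2/4/8 byte encoding.
    debug_assert!(v < (1 << 62));
    let (prefix, len) = match v {
        0..=0x3f => (0b00u8, 1usize),
        0x40..=0x3fff => (0b01, 2),
        0x4000..=0x3fff_ffff => (0b10, 4),
        _ => (0b11, 8),
    };
    let bytes = v.to_be_bytes();
    let mut encoded = bytes[8 - len..].to_vec();
    encoded[0] |= prefix << 6;
    out.extend_from_slice(&encoded);
}

fn encode_settings(entries: &[(u64, u64)]) -> Vec<u8> {
    let mut payload = Vec::new();
    for &(id, value) in entries {
        encode_varint(id, &mut payload);
        encode_varint(value, &mut payload);
    }
    let mut frame = Vec::new();
    encode_varint(0x04, &mut frame); // SETTINGS frame type (RFC 9114 §7.2.4)
    encode_varint(payload.len() as u64, &mut frame); // payload length
    frame.extend_from_slice(&payload);
    frame
}

// e.g. advertising the identifiers introduced in this change:
// encode_settings(&[
//     (0x8, 1),         // ENABLE_CONNECT_PROTOCOL (RFC 9220)
//     (0x33, 1),        // H3_DATAGRAM (RFC 9297)
//     (0x2B60_3742, 1), // ENABLE_WEBTRANSPORT
//     (0x2B60_3743, 4), // WEBTRANSPORT_MAX_SESSIONS
// ]);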
settings.insert(identifier, value)?; + } else { + tracing::warn!("Unsupported setting: {:#x?}", identifier); } } Ok(settings) @@ -594,6 +648,10 @@ mod tests { (SettingId::QPACK_MAX_TABLE_CAPACITY, 0xfad2), (SettingId::QPACK_MAX_BLOCKED_STREAMS, 0xfad3), (SettingId(95), 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), ], len: 4, }), @@ -607,6 +665,10 @@ mod tests { (SettingId::QPACK_MAX_BLOCKED_STREAMS, 0xfad3), // check without the Grease setting because this is ignored (SettingId(0), 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), ], len: 3, }), diff --git a/h3/src/proto/headers.rs b/h3/src/proto/headers.rs index 86f8de74..d081709e 100644 --- a/h3/src/proto/headers.rs +++ b/h3/src/proto/headers.rs @@ -8,10 +8,10 @@ use std::{ use http::{ header::{self, HeaderName, HeaderValue}, uri::{self, Authority, Parts, PathAndQuery, Scheme, Uri}, - HeaderMap, Method, StatusCode, + Extensions, HeaderMap, Method, StatusCode, }; -use crate::qpack::HeaderField; +use crate::{ext::Protocol, qpack::HeaderField}; #[derive(Debug)] #[cfg_attr(test, derive(PartialEq, Clone))] @@ -22,12 +22,18 @@ pub struct Header { #[allow(clippy::len_without_is_empty)] impl Header { - pub fn request(method: Method, uri: Uri, fields: HeaderMap) -> Result { + /// Creates a new `Header` frame data suitable for sending a request + pub fn request( + method: Method, + uri: Uri, + fields: HeaderMap, + ext: Extensions, + ) -> Result { match (uri.authority(), fields.get("host")) { (None, None) => Err(HeaderError::MissingAuthority), (Some(a), Some(h)) if a.as_str() != h => Err(HeaderError::ContradictedAuthority), _ => Ok(Self { - pseudo: Pseudo::request(method, uri), + pseudo: Pseudo::request(method, uri, ext), fields, }), } @@ -50,7 +56,9 @@ impl Header { } } - pub fn into_request_parts(self) -> Result<(Method, Uri, HeaderMap), HeaderError> { + pub fn into_request_parts( + self, + ) -> Result<(Method, Uri, Option, HeaderMap), HeaderError> { let mut uri = Uri::builder(); if let Some(path) = self.pseudo.path { @@ -92,6 +100,7 @@ impl Header { Ok(( self.pseudo.method.ok_or(HeaderError::MissingMethod)?, uri.build().map_err(HeaderError::InvalidRequest)?, + self.pseudo.protocol, self.fields, )) } @@ -172,6 +181,10 @@ impl Iterator for HeaderIter { if let Some(status) = pseudo.status.take() { return Some((":status", status.as_str()).into()); } + + if let Some(protocol) = pseudo.protocol.take() { + return Some((":protocol", protocol.as_str().as_bytes()).into()); + } } self.pseudo = None; @@ -221,6 +234,10 @@ impl TryFrom> for Header { Field::Header((n, v)) => { fields.append(n, v); } + Field::Protocol(p) => { + pseudo.protocol = Some(p); + pseudo.len += 1; + } } } @@ -234,6 +251,7 @@ enum Field { Authority(Authority), Path(PathAndQuery), Status(StatusCode), + Protocol(Protocol), Header((HeaderName, HeaderValue)), } @@ -257,9 +275,15 @@ impl Field { //# character not permitted in a field value MUST be treated as //# malformed. + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2 + //= type=implication + //# A request or + //# response containing uppercase characters in field names MUST be + //# treated as malformed. 
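The `Extensions` argument now threaded through `Header::request` is what carries the `:protocol` pseudo-header for extended CONNECT (RFC 8441 / RFC 9220): for a CONNECT request, the protocol is read out of the request extensions and emitted alongside the other pseudo-headers. A usage sketch, assuming `h3::ext::Protocol` exposes a WebTransport constant along these lines:

use h3::ext::Protocol;
use http::{Method, Request};

// Build an extended-CONNECT request; when the headers are encoded, the protocol
// stored in the extensions becomes the `:protocol` pseudo-header.
fn webtransport_connect(uri: &str) -> http::Result<Request<()>> {
    let mut req = Request::builder().method(Method::CONNECT).uri(uri).body(())?;
    // Assumed constant name; this change only shows that a `Protocol` type lives in
    // `h3::ext` and is read from the extensions for CONNECT requests.
    req.extensions_mut().insert(Protocol::WEB_TRANSPORT);
    Ok(req)
}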
+ if name[0] != b':' { return Ok(Field::Header(( - HeaderName::from_bytes(name).map_err(|_| HeaderError::invalid_name(name))?, + HeaderName::from_lowercase(name).map_err(|_| HeaderError::invalid_name(name))?, HeaderValue::from_bytes(value.as_ref()) .map_err(|_| HeaderError::invalid_value(name, value))?, ))); @@ -277,6 +301,7 @@ impl Field { StatusCode::from_bytes(value.as_ref()) .map_err(|_| HeaderError::invalid_value(name, value))?, ), + b":protocol" => Field::Protocol(try_value(name, value)?), _ => return Err(HeaderError::invalid_name(name)), }) } @@ -316,12 +341,14 @@ struct Pseudo { // Response status: Option, + protocol: Option, + len: usize, } #[allow(clippy::len_without_is_empty)] impl Pseudo { - fn request(method: Method, uri: Uri) -> Self { + fn request(method: Method, uri: Uri, ext: Extensions) -> Self { let Parts { scheme, authority, @@ -345,7 +372,16 @@ impl Pseudo { }, ); - let len = 3 + if authority.is_some() { 1 } else { 0 }; + // If the method is connect, the `:protocol` pseudo-header MAY be defined + // + // See: [https://www.rfc-editor.org/rfc/rfc8441#section-4] + let protocol = if method == Method::CONNECT { + ext.get::().copied() + } else { + None + }; + + let len = 3 + authority.is_some() as usize + protocol.is_some() as usize; //= https://www.rfc-editor.org/rfc/rfc9114#section-4.3 //= type=implication @@ -364,6 +400,7 @@ impl Pseudo { authority, path: Some(path), status: None, + protocol, len, } } @@ -381,6 +418,7 @@ impl Pseudo { path: None, status: Some(status), len: 1, + protocol: None, } } diff --git a/h3/src/proto/stream.rs b/h3/src/proto/stream.rs index a3d8fea5..2d525167 100644 --- a/h3/src/proto/stream.rs +++ b/h3/src/proto/stream.rs @@ -5,6 +5,8 @@ use std::{ ops::Add, }; +use crate::webtransport::SessionId; + use super::{ coding::{BufExt, BufMutExt, Decode, Encode, UnexpectedEnd}, varint::VarInt, @@ -26,6 +28,8 @@ stream_types! 
{ PUSH = 0x01, ENCODER = 0x02, DECODER = 0x03, + WEBTRANSPORT_BIDI = 0x41, + WEBTRANSPORT_UNI = 0x54, } impl StreamType { @@ -59,6 +63,7 @@ impl fmt::Display for StreamType { &StreamType::CONTROL => write!(f, "Control"), &StreamType::ENCODER => write!(f, "Encoder"), &StreamType::DECODER => write!(f, "Decoder"), + &StreamType::WEBTRANSPORT_UNI => write!(f, "WebTransportUni"), x => write!(f, "StreamType({})", x.0), } } @@ -112,11 +117,11 @@ impl StreamId { /// Create a new StreamId const fn new(index: u64, dir: Dir, initiator: Side) -> Self { - StreamId((index as u64) << 2 | (dir as u64) << 1 | initiator as u64) + StreamId((index) << 2 | (dir as u64) << 1 | initiator as u64) } /// Distinguishes streams of the same initiator and directionality - fn index(self) -> u64 { + pub fn index(self) -> u64 { self.0 >> 2 } @@ -128,6 +133,10 @@ impl StreamId { Dir::Uni } } + + pub(crate) fn into_inner(self) -> u64 { + self.0 + } } impl TryFrom for StreamId { @@ -154,7 +163,7 @@ impl From for VarInt { /// Invalid StreamId, for example because it's too large #[derive(Debug, PartialEq)] -pub struct InvalidStreamId(u64); +pub struct InvalidStreamId(pub(crate) u64); impl Display for InvalidStreamId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -181,6 +190,12 @@ impl Add for StreamId { } } +impl From for StreamId { + fn from(value: SessionId) -> Self { + Self(value.into_inner()) + } +} + #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Side { /// The initiator of a connection diff --git a/h3/src/proto/varint.rs b/h3/src/proto/varint.rs index 9af8e28f..60cdacf5 100644 --- a/h3/src/proto/varint.rs +++ b/h3/src/proto/varint.rs @@ -1,4 +1,4 @@ -use std::{convert::TryInto, fmt}; +use std::{convert::TryInto, fmt, ops::Div}; use bytes::{Buf, BufMut}; @@ -12,6 +12,14 @@ pub use super::coding::UnexpectedEnd; #[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct VarInt(pub(crate) u64); +impl Div for VarInt { + type Output = Self; + + fn div(self, rhs: u64) -> Self::Output { + Self(self.0 / rhs) + } +} + impl VarInt { /// The largest representable value pub const MAX: VarInt = VarInt((1 << 62) - 1); diff --git a/h3/src/qpack/block.rs b/h3/src/qpack/block.rs index 9614f2ad..1ba16306 100644 --- a/h3/src/qpack/block.rs +++ b/h3/src/qpack/block.rs @@ -177,17 +177,30 @@ impl HeaderPrefix { pub fn decode(buf: &mut R) -> Result { let (_, encoded_insert_count) = prefix_int::decode(8, buf)?; let (sign_negative, delta_base) = prefix_int::decode(7, buf)?; + + if encoded_insert_count > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + if delta_base > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + Ok(Self { - encoded_insert_count, - delta_base, + encoded_insert_count: encoded_insert_count as usize, + delta_base: delta_base as usize, sign_negative: sign_negative == 1, }) } pub fn encode(&self, buf: &mut W) { let sign_bit = if self.sign_negative { 1 } else { 0 }; - prefix_int::encode(8, 0, self.encoded_insert_count, buf); - prefix_int::encode(7, sign_bit, self.delta_base, buf); + prefix_int::encode(8, 0, self.encoded_insert_count as u64, buf); + prefix_int::encode(7, sign_bit, self.delta_base as u64, buf); } } @@ -200,16 +213,32 @@ pub enum Indexed { impl Indexed { pub fn decode(buf: &mut R) -> Result { match prefix_int::decode(6, buf)? 
{ - (0b11, i) => Ok(Indexed::Static(i)), - (0b10, i) => Ok(Indexed::Dynamic(i)), + (0b11, i) => { + if i > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + Ok(Indexed::Static(i as usize)) + } + (0b10, i) => { + if i > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + Ok(Indexed::Dynamic(i as usize)) + } (f, _) => Err(ParseError::InvalidPrefix(f)), } } pub fn encode(&self, buf: &mut W) { match self { - Indexed::Static(i) => prefix_int::encode(6, 0b11, *i, buf), - Indexed::Dynamic(i) => prefix_int::encode(6, 0b10, *i, buf), + Indexed::Static(i) => prefix_int::encode(6, 0b11, *i as u64, buf), + Indexed::Dynamic(i) => prefix_int::encode(6, 0b10, *i as u64, buf), } } } @@ -220,13 +249,21 @@ pub struct IndexedWithPostBase(pub usize); impl IndexedWithPostBase { pub fn decode(buf: &mut R) -> Result { match prefix_int::decode(4, buf)? { - (0b0001, i) => Ok(IndexedWithPostBase(i)), + (0b0001, i) => { + if i > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + Ok(IndexedWithPostBase(i as usize)) + } (f, _) => Err(ParseError::InvalidPrefix(f)), } } pub fn encode(&self, buf: &mut W) { - prefix_int::encode(4, 0b0001, self.0, buf) + prefix_int::encode(4, 0b0001, self.0 as u64, buf) } } @@ -253,14 +290,30 @@ impl LiteralWithNameRef { pub fn decode(buf: &mut R) -> Result { match prefix_int::decode(4, buf)? { - (f, i) if f & 0b0101 == 0b0101 => Ok(LiteralWithNameRef::new_static( - i, - prefix_string::decode(8, buf)?, - )), - (f, i) if f & 0b0101 == 0b0100 => Ok(LiteralWithNameRef::new_dynamic( - i, - prefix_string::decode(8, buf)?, - )), + (f, i) if f & 0b0101 == 0b0101 => { + if i > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + Ok(LiteralWithNameRef::new_static( + i as usize, + prefix_string::decode(8, buf)?, + )) + } + (f, i) if f & 0b0101 == 0b0100 => { + if i > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + Ok(LiteralWithNameRef::new_dynamic( + i as usize, + prefix_string::decode(8, buf)?, + )) + } (f, _) => Err(ParseError::InvalidPrefix(f)), } } @@ -268,11 +321,11 @@ impl LiteralWithNameRef { pub fn encode(&self, buf: &mut W) -> Result<(), prefix_string::Error> { match self { LiteralWithNameRef::Static { index, value } => { - prefix_int::encode(4, 0b0101, *index, buf); + prefix_int::encode(4, 0b0101, *index as u64, buf); prefix_string::encode(8, 0, value, buf)?; } LiteralWithNameRef::Dynamic { index, value } => { - prefix_int::encode(4, 0b0100, *index, buf); + prefix_int::encode(4, 0b0100, *index as u64, buf); prefix_string::encode(8, 0, value, buf)?; } } @@ -296,16 +349,24 @@ impl LiteralWithPostBaseNameRef { pub fn decode(buf: &mut R) -> Result { match prefix_int::decode(3, buf)? 
{ - (f, i) if f & 0b1111_0000 == 0 => Ok(LiteralWithPostBaseNameRef::new( - i, - prefix_string::decode(8, buf)?, - )), + (f, i) if f & 0b1111_0000 == 0 => { + if i > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + + Ok(LiteralWithPostBaseNameRef::new( + i as usize, + prefix_string::decode(8, buf)?, + )) + } (f, _) => Err(ParseError::InvalidPrefix(f)), } } pub fn encode(&self, buf: &mut W) -> Result<(), prefix_string::Error> { - prefix_int::encode(3, 0b0000, self.index, buf); + prefix_int::encode(3, 0b0000, self.index as u64, buf); prefix_string::encode(8, 0, &self.value, buf)?; Ok(()) } @@ -347,6 +408,7 @@ impl Literal { #[cfg(test)] mod test { use super::*; + use std::convert::TryInto; use std::io::Cursor; const TABLE_SIZE: usize = 4096; @@ -424,7 +486,7 @@ mod test { #[test] fn base_index_too_small() { let mut buf = vec![]; - let encoded_largest_ref = (2 % (2 * TABLE_SIZE / 32)) + 1; + let encoded_largest_ref: u64 = ((2 % (2 * TABLE_SIZE / 32)) + 1).try_into().unwrap(); prefix_int::encode(8, 0, encoded_largest_ref, &mut buf); prefix_int::encode(7, 1, 2, &mut buf); // base index negative = 0 diff --git a/h3/src/qpack/decoder.rs b/h3/src/qpack/decoder.rs index 18a46416..799002bc 100644 --- a/h3/src/qpack/decoder.rs +++ b/h3/src/qpack/decoder.rs @@ -1,5 +1,5 @@ use bytes::{Buf, BufMut}; -use std::{fmt, io::Cursor}; +use std::{convert::TryInto, fmt, io::Cursor, num::TryFromIntError}; use tracing::trace; @@ -36,6 +36,7 @@ pub enum Error { BadBaseIndex(isize), UnexpectedEnd, HeaderTooLong(u64), + BufSize(TryFromIntError), } impl std::error::Error for Error {} @@ -53,6 +54,7 @@ impl std::fmt::Display for Error { Error::BadBaseIndex(i) => write!(f, "out of bounds base index: {}", i), Error::UnexpectedEnd => write!(f, "unexpected end"), Error::HeaderTooLong(_) => write!(f, "header too long"), + Error::BufSize(_) => write!(f, "number in buffer wrong size"), } } } @@ -126,7 +128,8 @@ impl Decoder { } if self.table.total_inserted() != inserted_on_start { - InsertCountIncrement(self.table.total_inserted() - inserted_on_start).encode(write); + InsertCountIncrement((self.table.total_inserted() - inserted_on_start).try_into()?) 
+ .encode(write); } Ok(self.table.total_inserted()) @@ -326,6 +329,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: TryFromIntError) -> Self { + Error::BufSize(error) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/h3/src/qpack/dynamic.rs b/h3/src/qpack/dynamic.rs index 5e149ee8..542e0ccf 100644 --- a/h3/src/qpack/dynamic.rs +++ b/h3/src/qpack/dynamic.rs @@ -667,7 +667,7 @@ mod tests { assert_eq!(table.fields.len(), 1); assert_eq!( - table.fields.get(0), + table.fields.front(), Some(&HeaderField::new("Name-Large", "Value-Large")) ); } @@ -736,7 +736,7 @@ mod tests { insert_fields(&mut table, vec![HeaderField::new("Name-C", "Value-C")]); assert_eq!( - table.fields.get(0), + table.fields.front(), Some(&HeaderField::new("Name-B", "Value-B")) ); assert_eq!( @@ -1042,7 +1042,7 @@ mod tests { assert_eq!(table.track_map.get(&1), Some(&1)); } let track_blocks = table.track_blocks; - let block = track_blocks.get(&stream_id).unwrap().get(0).unwrap(); + let block = track_blocks.get(&stream_id).unwrap().front().unwrap(); assert_eq!(block.get(&1), Some(&1)); assert_eq!(block.get(&2), Some(&1)); assert_eq!(block.get(&3), Some(&1)); diff --git a/h3/src/qpack/encoder.rs b/h3/src/qpack/encoder.rs index 68ce0889..567b685d 100644 --- a/h3/src/qpack/encoder.rs +++ b/h3/src/qpack/encoder.rs @@ -239,9 +239,8 @@ impl Action { let first = buf.chunk()[0]; let instruction = match DecoderInstruction::decode(first) { DecoderInstruction::Unknown => return Err(Error::UnknownDecoderInstruction(first)), - DecoderInstruction::InsertCountIncrement => { - InsertCountIncrement::decode(&mut buf)?.map(|x| Action::ReceivedRefIncrement(x.0)) - } + DecoderInstruction::InsertCountIncrement => InsertCountIncrement::decode(&mut buf)? + .map(|x| Action::ReceivedRefIncrement(x.0 as usize)), DecoderInstruction::HeaderAck => { HeaderAck::decode(&mut buf)?.map(|x| Action::Untrack(x.0)) } diff --git a/h3/src/qpack/field.rs b/h3/src/qpack/field.rs index d2b96e51..22a9f6c7 100644 --- a/h3/src/qpack/field.rs +++ b/h3/src/qpack/field.rs @@ -36,7 +36,7 @@ impl HeaderField { T: Into>, { Self { - name: self.name.to_owned(), + name: self.name.clone(), value: Cow::Owned(value.into()), } } diff --git a/h3/src/qpack/mod.rs b/h3/src/qpack/mod.rs index 7189dc3a..ed12af96 100644 --- a/h3/src/qpack/mod.rs +++ b/h3/src/qpack/mod.rs @@ -1,9 +1,6 @@ pub use self::{ - decoder::{ - ack_header, decode_stateless, stream_canceled, Decoded, Decoder, Error as DecoderError, - }, - dynamic::Error as DynamicTableError, - encoder::{encode_stateless, Encoder, Error as EncoderError}, + decoder::{decode_stateless, Decoded, Error as DecoderError}, + encoder::{encode_stateless, Error as EncoderError}, field::HeaderField, }; diff --git a/h3/src/qpack/prefix_int.rs b/h3/src/qpack/prefix_int.rs index 1cfb2415..65a030dc 100644 --- a/h3/src/qpack/prefix_int.rs +++ b/h3/src/qpack/prefix_int.rs @@ -19,7 +19,7 @@ impl std::fmt::Display for Error { } } -pub fn decode(size: u8, buf: &mut B) -> Result<(u8, usize), Error> { +pub fn decode(size: u8, buf: &mut B) -> Result<(u8, u64), Error> { assert!(size <= 8); let mut first = buf.get::()?; @@ -31,13 +31,13 @@ pub fn decode(size: u8, buf: &mut B) -> Result<(u8, usize), Error> { // if first < 2usize.pow(size) - 1 if first < mask { - return Ok((flags, first as usize)); + return Ok((flags, first as u64)); } - let mut value = mask as usize; + let mut value = mask as u64; let mut power = 0usize; loop { - let byte = buf.get::()? as usize; + let byte = buf.get::()? 
as u64; value += (byte & 127) << power; power += 7; @@ -53,7 +53,7 @@ pub fn decode(size: u8, buf: &mut B) -> Result<(u8, usize), Error> { Ok((flags, value)) } -pub fn encode(size: u8, flags: u8, value: usize, buf: &mut B) { +pub fn encode(size: u8, flags: u8, value: u64, buf: &mut B) { assert!(size <= 8); // NOTE: following casts to u8 intend to trim the most significant bits, they are used as a // workaround for shiftoverflow errors when size == 8. @@ -61,13 +61,13 @@ pub fn encode(size: u8, flags: u8, value: usize, buf: &mut B) { let flags = ((flags as usize) << size) as u8; // if value < 2usize.pow(size) - 1 - if value < (mask as usize) { + if value < (mask as u64) { buf.write(flags | value as u8); return; } buf.write(mask | flags); - let mut remaining = value - mask as usize; + let mut remaining = value - mask as u64; while remaining >= 128 { let rest = (remaining % 128) as u8; @@ -93,7 +93,7 @@ impl From for Error { mod test { use std::io::Cursor; - fn check_codec(size: u8, flags: u8, value: usize, data: &[u8]) { + fn check_codec(size: u8, flags: u8, value: u64, data: &[u8]) { let mut buf = Vec::new(); super::encode(size, flags, value, &mut buf); assert_eq!(buf, data); @@ -110,7 +110,7 @@ mod test { check_codec( 5, 0b010, - usize::max_value(), + u64::max_value(), &[95, 224, 255, 255, 255, 255, 255, 255, 255, 255, 1], ); } @@ -122,7 +122,7 @@ mod test { check_codec( 8, 0, - usize::max_value(), + u64::max_value(), &[255, 128, 254, 255, 255, 255, 255, 255, 255, 255, 1], ); } diff --git a/h3/src/qpack/prefix_string/mod.rs b/h3/src/qpack/prefix_string/mod.rs index c83f9204..4a0967ba 100644 --- a/h3/src/qpack/prefix_string/mod.rs +++ b/h3/src/qpack/prefix_string/mod.rs @@ -2,14 +2,16 @@ mod bitwin; mod decode; mod encode; +use std::convert::TryInto; use std::fmt; +use std::num::TryFromIntError; use bytes::{Buf, BufMut}; pub use self::bitwin::BitWindow; pub use self::{ - decode::{DecodeIter, Error as HuffmanDecodingError, HpackStringDecode}, + decode::{Error as HuffmanDecodingError, HpackStringDecode}, encode::{Error as HuffmanEncodingError, HpackStringEncode}, }; @@ -22,6 +24,7 @@ pub enum Error { Integer(IntegerError), HuffmanDecoding(HuffmanDecodingError), HuffmanEncoding(HuffmanEncodingError), + BufSize(TryFromIntError), } impl std::fmt::Display for Error { @@ -31,12 +34,14 @@ impl std::fmt::Display for Error { Error::Integer(e) => write!(f, "could not parse integer: {}", e), Error::HuffmanDecoding(e) => write!(f, "Huffman decode failed: {:?}", e), Error::HuffmanEncoding(e) => write!(f, "Huffman encode failed: {:?}", e), + Error::BufSize(_) => write!(f, "number in buffer wrong size"), } } } pub fn decode(size: u8, buf: &mut B) -> Result, Error> { let (flags, len) = prefix_int::decode(size - 1, buf)?; + let len: usize = len.try_into()?; if buf.remaining() < len { return Err(Error::UnexpectedEnd); } @@ -56,7 +61,7 @@ pub fn decode(size: u8, buf: &mut B) -> Result, Error> { pub fn encode(size: u8, flags: u8, value: &[u8], buf: &mut B) -> Result<(), Error> { let encoded = Vec::from(value).hpack_encode()?; - prefix_int::encode(size - 1, flags << 1 | 1, encoded.len(), buf); + prefix_int::encode(size - 1, flags << 1 | 1, encoded.len().try_into()?, buf); for byte in encoded { buf.write(byte); } @@ -84,6 +89,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: TryFromIntError) -> Self { + Error::BufSize(error) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/h3/src/qpack/stream.rs b/h3/src/qpack/stream.rs index 472b885e..d2ec4c8f 100644 --- a/h3/src/qpack/stream.rs 
+++ b/h3/src/qpack/stream.rs @@ -1,4 +1,5 @@ use bytes::{Buf, BufMut}; +use std::convert::TryInto; use super::{ parse_error::ParseError, @@ -96,6 +97,9 @@ impl InsertWithNameRef { Err(IntError::UnexpectedEnd) => return Ok(None), Err(e) => return Err(e.into()), }; + let index: usize = index + .try_into() + .map_err(|_e| ParseError::Integer(crate::qpack::prefix_int::Error::Overflow))?; let value = match prefix_string::decode(8, buf) { Ok(x) => x, @@ -113,11 +117,11 @@ impl InsertWithNameRef { pub fn encode(&self, buf: &mut W) -> Result<(), prefix_string::Error> { match self { InsertWithNameRef::Static { index, value } => { - prefix_int::encode(6, 0b11, *index, buf); + prefix_int::encode(6, 0b11, *index as u64, buf); prefix_string::encode(8, 0, value, buf)?; } InsertWithNameRef::Dynamic { index, value } => { - prefix_int::encode(6, 0b10, *index, buf); + prefix_int::encode(6, 0b10, *index as u64, buf); prefix_string::encode(8, 0, value, buf)?; } } @@ -166,7 +170,14 @@ pub struct Duplicate(pub usize); impl Duplicate { pub fn decode(buf: &mut R) -> Result, ParseError> { let index = match prefix_int::decode(5, buf) { - Ok((0, x)) => x, + Ok((0, x)) => { + if x > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + x as usize + } Ok((f, _)) => return Err(ParseError::InvalidPrefix(f)), Err(IntError::UnexpectedEnd) => return Ok(None), Err(e) => return Err(e.into()), @@ -175,7 +186,7 @@ impl Duplicate { } pub fn encode(&self, buf: &mut W) { - prefix_int::encode(5, 0, self.0, buf); + prefix_int::encode(5, 0, self.0 as u64, buf); } } @@ -185,7 +196,14 @@ pub struct DynamicTableSizeUpdate(pub usize); impl DynamicTableSizeUpdate { pub fn decode(buf: &mut R) -> Result, ParseError> { let size = match prefix_int::decode(5, buf) { - Ok((0b001, x)) => x, + Ok((0b001, x)) => { + if x > (usize::MAX as u64) { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + x as usize + } Ok((f, _)) => return Err(ParseError::InvalidPrefix(f)), Err(IntError::UnexpectedEnd) => return Ok(None), Err(e) => return Err(e.into()), @@ -194,7 +212,7 @@ impl DynamicTableSizeUpdate { } pub fn encode(&self, buf: &mut W) { - prefix_int::encode(5, 0b001, self.0, buf); + prefix_int::encode(5, 0b001, self.0 as u64, buf); } } @@ -245,12 +263,19 @@ impl DecoderInstruction { } #[derive(Debug, PartialEq)] -pub struct InsertCountIncrement(pub usize); +pub struct InsertCountIncrement(pub u8); impl InsertCountIncrement { pub fn decode(buf: &mut R) -> Result, ParseError> { let insert_count = match prefix_int::decode(6, buf) { - Ok((0b00, x)) => x, + Ok((0b00, x)) => { + if x > 64 { + return Err(ParseError::Integer( + crate::qpack::prefix_int::Error::Overflow, + )); + } + x as u8 + } Ok((f, _)) => return Err(ParseError::InvalidPrefix(f)), Err(IntError::UnexpectedEnd) => return Ok(None), Err(e) => return Err(e.into()), @@ -259,7 +284,7 @@ impl InsertCountIncrement { } pub fn encode(&self, buf: &mut W) { - prefix_int::encode(6, 0b00, self.0, buf); + prefix_int::encode(6, 0b00, self.0 as u64, buf); } } @@ -269,7 +294,7 @@ pub struct HeaderAck(pub u64); impl HeaderAck { pub fn decode(buf: &mut R) -> Result, ParseError> { let stream_id = match prefix_int::decode(7, buf) { - Ok((0b1, x)) => x as u64, + Ok((0b1, x)) => x, Ok((f, _)) => return Err(ParseError::InvalidPrefix(f)), Err(IntError::UnexpectedEnd) => return Ok(None), Err(e) => return Err(e.into()), @@ -278,7 +303,7 @@ impl HeaderAck { } pub fn encode(&self, buf: &mut W) { - prefix_int::encode(7, 0b1, 
self.0 as usize, buf); + prefix_int::encode(7, 0b1, self.0, buf); } } @@ -288,7 +313,7 @@ pub struct StreamCancel(pub u64); impl StreamCancel { pub fn decode(buf: &mut R) -> Result, ParseError> { let stream_id = match prefix_int::decode(6, buf) { - Ok((0b01, x)) => x as u64, + Ok((0b01, x)) => x, Ok((f, _)) => return Err(ParseError::InvalidPrefix(f)), Err(IntError::UnexpectedEnd) => return Ok(None), Err(e) => return Err(e.into()), @@ -297,7 +322,7 @@ impl StreamCancel { } pub fn encode(&self, buf: &mut W) { - prefix_int::encode(6, 0b01, self.0 as usize, buf); + prefix_int::encode(6, 0b01, self.0, buf); } } diff --git a/h3/src/qpack/tests.rs b/h3/src/qpack/tests.rs index 5ae34d51..f72f4510 100644 --- a/h3/src/qpack/tests.rs +++ b/h3/src/qpack/tests.rs @@ -1,4 +1,6 @@ -use crate::qpack::{dynamic::DynamicTable, Decoded, Decoder, DecoderError, Encoder, HeaderField}; +use crate::qpack::decoder::Decoder; +use crate::qpack::encoder::Encoder; +use crate::qpack::{dynamic::DynamicTable, Decoded, DecoderError, HeaderField}; use std::io::Cursor; pub mod helpers { @@ -44,7 +46,7 @@ fn codec_basic_get() { ]; encoder - .encode(42, &mut block_buf, &mut enc_buf, header.clone().into_iter()) + .encode(42, &mut block_buf, &mut enc_buf, header.clone()) .unwrap(); let mut enc_cur = Cursor::new(&mut enc_buf); @@ -111,7 +113,7 @@ fn codec_table_size_0() { let mut decoder = Decoder::from(dec_table); encoder - .encode(42, &mut block_buf, &mut enc_buf, header.clone().into_iter()) + .encode(42, &mut block_buf, &mut enc_buf, header.clone()) .unwrap(); let mut enc_cur = Cursor::new(&mut enc_buf); @@ -146,7 +148,7 @@ fn codec_table_full() { let mut decoder = Decoder::from(dec_table); encoder - .encode(42, &mut block_buf, &mut enc_buf, header.clone().into_iter()) + .encode(42, &mut block_buf, &mut enc_buf, header.clone()) .unwrap(); let mut enc_cur = Cursor::new(&mut enc_buf); diff --git a/h3/src/quic.rs b/h3/src/quic.rs index d7203f7b..6e8722cb 100644 --- a/h3/src/quic.rs +++ b/h3/src/quic.rs @@ -7,6 +7,7 @@ use std::task::{self, Poll}; use bytes::Buf; +use crate::ext::Datagram; pub use crate::proto::stream::{InvalidStreamId, StreamId}; pub use crate::stream::WriteBuf; @@ -34,7 +35,7 @@ pub trait Connection: OpenStreams { /// The type produced by `poll_accept_recv()` type RecvStream: RecvStream; /// A producer of outgoing Unidirectional and Bidirectional streams. - type OpenStreams: OpenStreams; + type OpenStreams: OpenStreams; /// Error type yielded by these trait methods type AcceptError: Into>; @@ -58,6 +59,33 @@ pub trait Connection: OpenStreams { fn opener(&self) -> Self::OpenStreams; } +/// Extends the `Connection` trait for sending datagrams +/// +/// See: +pub trait SendDatagramExt { + /// The error type that can occur when sending a datagram + type Error: Into>; + + /// Send a datagram + fn send_datagram(&mut self, data: Datagram) -> Result<(), Self::Error>; +} + +/// Extends the `Connection` trait for receiving datagrams +/// +/// See: +pub trait RecvDatagramExt { + /// The type of `Buf` for *raw* datagrams (without the stream_id decoded) + type Buf: Buf; + /// The error type that can occur when receiving a datagram + type Error: Into>; + + /// Poll the connection for incoming datagrams. 
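`SendDatagramExt` and `RecvDatagramExt` are deliberately separate, opt-in extension traits: a QUIC backend that supports unreliable datagrams (RFC 9297) implements them alongside `quic::Connection`, and generic code bounds on them only where datagrams are actually needed. A sketch of driving the receive side, assuming the elided return type is `Poll<Result<Option<Self::Buf>, Self::Error>>`:

use std::task::{Context, Poll};

use h3::quic::RecvDatagramExt;

// Drain every datagram that is ready right now; a real driver would be woken
// through `cx` and call this again after it has returned.
fn drain_ready_datagrams<C: RecvDatagramExt>(
    conn: &mut C,
    cx: &mut Context<'_>,
) -> Result<usize, C::Error> {
    let mut received = 0;
    loop {
        match conn.poll_accept_datagram(cx) {
            Poll::Ready(Ok(Some(_datagram))) => received += 1,
            Poll::Ready(Ok(None)) => return Ok(received), // peer closed the connection
            Poll::Ready(Err(e)) => return Err(e),
            Poll::Pending => return Ok(received),
        }
    }
}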
+ fn poll_accept_datagram( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>>; +} + /// Trait for opening outgoing streams pub trait OpenStreams { /// The type produced by `poll_open_bidi()` @@ -101,7 +129,21 @@ pub trait SendStream { fn reset(&mut self, reset_code: u64); /// Get QUIC send stream id - fn id(&self) -> StreamId; + fn send_id(&self) -> StreamId; +} + +/// Allows sending unframed pure bytes to a stream. Similar to [`AsyncWrite`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWrite.html) +pub trait SendStreamUnframed: SendStream { + /// Attempts write data into the stream. + /// + /// Returns the number of bytes written. + /// + /// `buf` is advanced by the number of bytes written. + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll>; } /// A trait describing the "receive" actions of a QUIC stream. @@ -122,6 +164,9 @@ pub trait RecvStream { /// Send a `STOP_SENDING` QUIC code. fn stop_sending(&mut self, error_code: u64); + + /// Get QUIC send stream id + fn recv_id(&self) -> StreamId; } /// Optional trait to allow "splitting" a bidirectional stream into two sides. diff --git a/h3/src/server.rs b/h3/src/server.rs deleted file mode 100644 index 6c43084f..00000000 --- a/h3/src/server.rs +++ /dev/null @@ -1,727 +0,0 @@ -//! This module provides methods to create a http/3 Server. -//! -//! It allows to accept incoming requests, and send responses. -//! -//! # Examples -//! -//! ## Simple example -//! ```rust -//! async fn doc(conn: C) -//! where -//! C: h3::quic::Connection, -//! >::BidiStream: Send + 'static -//! { -//! let mut server_builder = h3::server::builder(); -//! // Build the Connection -//! let mut h3_conn = server_builder.build(conn).await.unwrap(); -//! loop { -//! // Accept incoming requests -//! match h3_conn.accept().await { -//! Ok(Some((req, mut stream))) => { -//! // spawn a new task to handle the request -//! tokio::spawn(async move { -//! // build a http response -//! let response = http::Response::builder().status(http::StatusCode::OK).body(()).unwrap(); -//! // send the response to the wire -//! stream.send_response(response).await.unwrap(); -//! // send some date -//! stream.send_data(bytes::Bytes::from("test")).await.unwrap(); -//! // finnish the stream -//! stream.finish().await.unwrap(); -//! }); -//! } -//! Ok(None) => { -//! // break if no Request is accepted -//! break; -//! } -//! Err(err) => { -//! match err.get_error_level() { -//! // break on connection errors -//! h3::error::ErrorLevel::ConnectionError => break, -//! // continue on stream errors -//! h3::error::ErrorLevel::StreamError => continue, -//! } -//! } -//! } -//! } -//! } -//! ``` -//! -//! ## File server -//! 
A ready-to-use example of a file server is available [here](https://github.com/hyperium/h3/blob/master/examples/client.rs) - -use std::{ - collections::HashSet, - convert::TryFrom, - sync::Arc, - task::{Context, Poll}, -}; - -use bytes::{Buf, BytesMut}; -use futures_util::future; -use http::{response, HeaderMap, Request, Response, StatusCode}; -use quic::StreamId; -use tokio::sync::mpsc; - -use crate::{ - connection::{self, ConnectionInner, ConnectionState, SharedStateRef}, - error::{Code, Error, ErrorLevel}, - frame::FrameStream, - proto::{frame::Frame, headers::Header, push::PushId, varint::VarInt}, - qpack, - quic::{self, RecvStream as _, SendStream as _}, - stream, -}; -use tracing::{error, trace, warn}; - -/// Create a builder of HTTP/3 server connections -/// -/// This function creates a [`Builder`] that carries settings that can -/// be shared between server connections. -pub fn builder() -> Builder { - Builder::new() -} - -/// Server connection driver -/// -/// The [`Connection`] struct manages a connection from the side of the HTTP/3 server -/// -/// Create a new Instance with [`Connection::new()`]. -/// Accept incoming requests with [`Connection::accept()`]. -/// And shutdown a connection with [`Connection::shutdown()`]. -pub struct Connection -where - C: quic::Connection, - B: Buf, -{ - inner: ConnectionInner, - max_field_section_size: u64, - // List of all incoming streams that are currently running. - ongoing_streams: HashSet, - // Let the streams tell us when they are no longer running. - request_end_recv: mpsc::UnboundedReceiver, - request_end_send: mpsc::UnboundedSender, - // Has a GOAWAY frame been sent? If so, this StreamId is the last we are willing to accept. - sent_closing: Option, - // Has a GOAWAY frame been received? If so, this is PushId the last the remote will accept. - recv_closing: Option, - // The id of the last stream received by this connection. - last_accepted_stream: Option, -} - -impl ConnectionState for Connection -where - C: quic::Connection, - B: Buf, -{ - fn shared_state(&self) -> &SharedStateRef { - &self.inner.shared - } -} - -impl Connection -where - C: quic::Connection, - B: Buf, -{ - /// Create a new HTTP/3 server connection with default settings - /// - /// Use a custom [`Builder`] with [`builder()`] to create a connection - /// with different settings. - /// Provide a Connection which implements [`quic::Connection`]. - pub async fn new(conn: C) -> Result { - builder().build(conn).await - } -} - -impl Connection -where - C: quic::Connection, - B: Buf, -{ - /// Accept an incoming request. - /// - /// It returns a tuple with a [`http::Request`] and an [`RequestStream`]. - /// The [`http::Request`] is the received request from the client. - /// The [`RequestStream`] can be used to send the response. - pub async fn accept( - &mut self, - ) -> Result, RequestStream)>, Error> { - // Accept the incoming stream - let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await { - Ok(Some(s)) => FrameStream::new(s), - Ok(None) => { - // We always send a last GoAway frame to the client, so it knows which was the last - // non-rejected request. 
- self.shutdown(0).await?; - return Ok(None); - } - Err(err) => { - match err.inner.kind { - crate::error::Kind::Closed => return Ok(None), - crate::error::Kind::Application { - code, - reason, - level: ErrorLevel::ConnectionError, - } => { - return Err(self.inner.close( - code, - reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), - )) - } - _ => return Err(err), - }; - } - }; - - let frame = future::poll_fn(|cx| stream.poll_next(cx)).await; - - let mut encoded = match frame { - Ok(Some(Frame::Headers(h))) => h, - - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 - //# If a client-initiated - //# stream terminates without enough of the HTTP message to provide a - //# complete response, the server SHOULD abort its response stream with - //# the error code H3_REQUEST_INCOMPLETE. - Ok(None) => { - return Err(self.inner.close( - Code::H3_REQUEST_INCOMPLETE, - "request stream closed before headers", - )) - } - - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 - //# Receipt of an invalid sequence of frames MUST be treated as a - //# connection error of type H3_FRAME_UNEXPECTED. - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 - //# A server MUST treat the - //# receipt of a PUSH_PROMISE frame as a connection error of type - //# H3_FRAME_UNEXPECTED. - Ok(Some(_)) => { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 - //# Receipt of an invalid sequence of frames MUST be treated as a - //# connection error of type H3_FRAME_UNEXPECTED. - // Close if the first frame is not a header frame - return Err(self.inner.close( - Code::H3_FRAME_UNEXPECTED, - "first request frame is not headers", - )); - } - Err(e) => { - let err: Error = e.into(); - if err.is_closed() { - return Ok(None); - } - match err.inner.kind { - crate::error::Kind::Closed => return Ok(None), - crate::error::Kind::Application { - code, - reason, - level: ErrorLevel::ConnectionError, - } => { - return Err(self.inner.close( - code, - reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), - )) - } - crate::error::Kind::Application { - code, - reason: _, - level: ErrorLevel::StreamError, - } => { - stream.reset(code.into()); - return Err(err); - } - _ => return Err(err), - }; - } - }; - - let mut request_stream = RequestStream { - request_end: Arc::new(RequestEnd { - request_end: self.request_end_send.clone(), - stream_id: stream.id(), - }), - inner: connection::RequestStream::new( - stream, - self.max_field_section_size, - self.inner.shared.clone(), - self.inner.send_grease_frame, - ), - }; - - let qpack::Decoded { fields, .. } = - match qpack::decode_stateless(&mut encoded, self.max_field_section_size) { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 - //# An HTTP/3 implementation MAY impose a limit on the maximum size of - //# the message header it will accept on an individual HTTP message. 
- Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => { - request_stream - .send_response( - http::Response::builder() - .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE) - .body(()) - .expect("header too big response"), - ) - .await?; - return Err(Error::header_too_big( - cancel_size, - self.max_field_section_size, - )); - } - Ok(decoded) => decoded, - Err(e) => { - let err: Error = e.into(); - if err.is_closed() { - return Ok(None); - } - match err.inner.kind { - crate::error::Kind::Closed => return Ok(None), - crate::error::Kind::Application { - code, - reason, - level: ErrorLevel::ConnectionError, - } => { - return Err(self.inner.close( - code, - reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), - )) - } - crate::error::Kind::Application { - code, - reason: _, - level: ErrorLevel::StreamError, - } => { - request_stream.stop_stream(code); - return Err(err); - } - _ => return Err(err), - }; - } - }; - - // Parse the request headers - let (method, uri, headers) = match Header::try_from(fields) { - Ok(header) => match header.into_request_parts() { - Ok(parts) => parts, - Err(err) => { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2 - //# Malformed requests or responses that are - //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR. - let error: Error = err.into(); - request_stream - .stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR)); - return Err(error); - } - }, - Err(err) => { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2 - //# Malformed requests or responses that are - //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR. - let error: Error = err.into(); - request_stream.stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR)); - return Err(error); - } - }; - // request_stream.stop_stream(Code::H3_MESSAGE_ERROR).await; - let mut req = http::Request::new(()); - *req.method_mut() = method; - *req.uri_mut() = uri; - *req.headers_mut() = headers; - *req.version_mut() = http::Version::HTTP_3; - // send the grease frame only once - self.inner.send_grease_frame = false; - - Ok(Some((req, request_stream))) - } - - /// Itiniate a graceful shutdown, accepting `max_request` potentially still in-flight - /// - /// See [connection shutdown](https://www.rfc-editor.org/rfc/rfc9114.html#connection-shutdown) for more information. - pub async fn shutdown(&mut self, max_requests: usize) -> Result<(), Error> { - let max_id = self - .last_accepted_stream - .map(|id| id + max_requests) - .unwrap_or(StreamId::FIRST_REQUEST); - - self.inner.shutdown(&mut self.sent_closing, max_id).await - } - - fn poll_accept_request( - &mut self, - cx: &mut Context<'_>, - ) -> Poll, Error>> { - let _ = self.poll_control(cx)?; - let _ = self.poll_requests_completion(cx); - loop { - match self.inner.poll_accept_request(cx) { - Poll::Ready(Err(x)) => break Poll::Ready(Err(x)), - Poll::Ready(Ok(None)) => { - if self.poll_requests_completion(cx).is_ready() { - break Poll::Ready(Ok(None)); - } else { - // Wait for all the requests to be finished, request_end_recv will wake - // us on each request completion. - break Poll::Pending; - } - } - Poll::Pending => { - if self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready() { - // The connection is now idle. - break Poll::Ready(Ok(None)); - } else { - return Poll::Pending; - } - } - Poll::Ready(Ok(Some(mut s))) => { - // When the connection is in a graceful shutdown procedure, reject all - // incoming requests not belonging to the grace interval. 
It's possible that - // some acceptable request streams arrive after rejected requests. - if let Some(max_id) = self.sent_closing { - if s.id() > max_id.into() { - s.stop_sending(Code::H3_REQUEST_REJECTED.value()); - s.reset(Code::H3_REQUEST_REJECTED.value()); - if self.poll_requests_completion(cx).is_ready() { - break Poll::Ready(Ok(None)); - } - continue; - } - } - self.last_accepted_stream = Some(s.id()); - self.ongoing_streams.insert(s.id()); - break Poll::Ready(Ok(Some(s))); - } - }; - } - } - - fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll> { - while let Poll::Ready(frame) = self.inner.poll_control(cx)? { - match frame { - Frame::Settings(_) => trace!("Got settings"), - Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?, - f @ Frame::MaxPushId(_) | f @ Frame::CancelPush(_) => { - warn!("Control frame ignored {:?}", f); - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.3 - //= type=TODO - //# If a server receives a CANCEL_PUSH frame for a push - //# ID that has not yet been mentioned by a PUSH_PROMISE frame, this MUST - //# be treated as a connection error of type H3_ID_ERROR. - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.7 - //= type=TODO - //# A MAX_PUSH_ID frame cannot reduce the maximum push - //# ID; receipt of a MAX_PUSH_ID frame that contains a smaller value than - //# previously received MUST be treated as a connection error of type - //# H3_ID_ERROR. - } - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 - //# A server MUST treat the - //# receipt of a PUSH_PROMISE frame as a connection error of type - //# H3_FRAME_UNEXPECTED. - frame => { - return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.with_reason( - format!("on server control stream: {:?}", frame), - ErrorLevel::ConnectionError, - ))) - } - } - } - Poll::Pending - } - - fn poll_requests_completion(&mut self, cx: &mut Context<'_>) -> Poll<()> { - loop { - match self.request_end_recv.poll_recv(cx) { - // The channel is closed - Poll::Ready(None) => return Poll::Ready(()), - // A request has completed - Poll::Ready(Some(id)) => { - self.ongoing_streams.remove(&id); - } - Poll::Pending => { - if self.ongoing_streams.is_empty() { - // Tell the caller there is not more ongoing requests. - // Still, the completion of future requests will wake us. - return Poll::Ready(()); - } else { - return Poll::Pending; - } - } - } - } - } -} - -impl Drop for Connection -where - C: quic::Connection, - B: Buf, -{ - fn drop(&mut self) { - self.inner.close(Code::H3_NO_ERROR, ""); - } -} - -//= https://www.rfc-editor.org/rfc/rfc9114#section-6.1 -//= type=TODO -//# In order to -//# permit these streams to open, an HTTP/3 server SHOULD configure non- -//# zero minimum values for the number of permitted streams and the -//# initial stream flow-control window. - -//= https://www.rfc-editor.org/rfc/rfc9114#section-6.1 -//= type=TODO -//# So as to not unnecessarily limit -//# parallelism, at least 100 request streams SHOULD be permitted at a -//# time. - -/// Builder of HTTP/3 server connections. -/// -/// Use this struct to create a new [`Connection`]. -/// Settings for the [`Connection`] can be provided here. 
-/// -/// # Example -/// -/// ```rust -/// fn doc(conn: C) -/// where -/// C: h3::quic::Connection, -/// B: bytes::Buf, -/// { -/// let mut server_builder = h3::server::builder(); -/// // Set the maximum header size -/// server_builder.max_field_section_size(1000); -/// // do not send grease types -/// server_builder.send_grease(false); -/// // Build the Connection -/// let mut h3_conn = server_builder.build(conn); -/// } -/// ``` -pub struct Builder { - pub(super) max_field_section_size: u64, - pub(super) send_grease: bool, -} - -impl Builder { - /// Creates a new [`Builder`] with default settings. - pub(super) fn new() -> Self { - Builder { - max_field_section_size: VarInt::MAX.0, - send_grease: true, - } - } - /// Set the maximum header size this client is willing to accept - /// - /// See [header size constraints] section of the specification for details. - /// - /// [header size constraints]: https://www.rfc-editor.org/rfc/rfc9114.html#name-header-size-constraints - pub fn max_field_section_size(&mut self, value: u64) -> &mut Self { - self.max_field_section_size = value; - self - } - - /// Send grease values to the Client. - /// See [setting](https://www.rfc-editor.org/rfc/rfc9114.html#settings-parameters), [frame](https://www.rfc-editor.org/rfc/rfc9114.html#frame-reserved) and [stream](https://www.rfc-editor.org/rfc/rfc9114.html#stream-grease) for more information. - pub fn send_grease(&mut self, value: bool) -> &mut Self { - self.send_grease = value; - self - } -} - -impl Builder { - /// Build an HTTP/3 connection from a QUIC connection - /// - /// This method creates a [`Connection`] instance with the settings in the [`Builder`]. - pub async fn build(&self, conn: C) -> Result, Error> - where - C: quic::Connection, - B: Buf, - { - let (sender, receiver) = mpsc::unbounded_channel(); - Ok(Connection { - inner: ConnectionInner::new( - conn, - self.max_field_section_size, - SharedStateRef::default(), - self.send_grease, - ) - .await?, - max_field_section_size: self.max_field_section_size, - request_end_send: sender, - request_end_recv: receiver, - ongoing_streams: HashSet::new(), - sent_closing: None, - recv_closing: None, - last_accepted_stream: None, - }) - } -} - -struct RequestEnd { - request_end: mpsc::UnboundedSender, - stream_id: StreamId, -} - -/// Manage request and response transfer for an incoming request -/// -/// The [`RequestStream`] struct is used to send and/or receive -/// information from the client. -pub struct RequestStream { - inner: connection::RequestStream, - request_end: Arc, -} - -impl AsMut> for RequestStream { - fn as_mut(&mut self) -> &mut connection::RequestStream { - &mut self.inner - } -} - -impl ConnectionState for RequestStream { - fn shared_state(&self) -> &SharedStateRef { - &self.inner.conn_state - } -} - -impl RequestStream -where - S: quic::RecvStream, -{ - /// Receive data sent from the client - pub async fn recv_data(&mut self) -> Result, Error> { - self.inner.recv_data().await - } - - /// Receive an optional set of trailers for the request - pub async fn recv_trailers(&mut self) -> Result, Error> { - self.inner.recv_trailers().await - } - - /// Tell the peer to stop sending into the underlying QUIC stream - pub fn stop_sending(&mut self, error_code: crate::error::Code) { - self.inner.stream.stop_sending(error_code) - } -} - -impl RequestStream -where - S: quic::SendStream, - B: Buf, -{ - /// Send the HTTP/3 response - /// - /// This should be called before trying to send any data with - /// [`RequestStream::send_data`]. 
- pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> { - let (parts, _) = resp.into_parts(); - let response::Parts { - status, headers, .. - } = parts; - let headers = Header::response(status, headers); - - let mut block = BytesMut::new(); - let mem_size = qpack::encode_stateless(&mut block, headers)?; - - let max_mem_size = self - .inner - .conn_state - .read("send_response") - .peer_max_field_section_size; - - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 - //# An implementation that - //# has received this parameter SHOULD NOT send an HTTP message header - //# that exceeds the indicated size, as the peer will likely refuse to - //# process it. - if mem_size > max_mem_size { - return Err(Error::header_too_big(mem_size, max_mem_size)); - } - - stream::write(&mut self.inner.stream, Frame::Headers(block.freeze())) - .await - .map_err(|e| self.maybe_conn_err(e))?; - - Ok(()) - } - - /// Send some data on the response body. - pub async fn send_data(&mut self, buf: B) -> Result<(), Error> { - self.inner.send_data(buf).await - } - - /// Stop a stream with an error code - /// - /// The code can be [`Code::H3_NO_ERROR`]. - pub fn stop_stream(&mut self, error_code: Code) { - self.inner.stop_stream(error_code); - } - - /// Send a set of trailers to end the response. - /// - /// Either [`RequestStream::finish`] or - /// [`RequestStream::send_trailers`] must be called to finalize a - /// request. - pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> { - self.inner.send_trailers(trailers).await - } - - /// End the response without trailers. - /// - /// Either [`RequestStream::finish`] or - /// [`RequestStream::send_trailers`] must be called to finalize a - /// request. - pub async fn finish(&mut self) -> Result<(), Error> { - self.inner.finish().await - } - - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.1 - //= type=TODO - //# Implementations SHOULD cancel requests by abruptly terminating any - //# directions of a stream that are still open. To do so, an - //# implementation resets the sending parts of streams and aborts reading - //# on the receiving parts of streams; see Section 2.4 of - //# [QUIC-TRANSPORT]. -} - -impl RequestStream -where - S: quic::BidiStream, - B: Buf, -{ - /// Splits the Request-Stream into send and receive. - /// This can be used the send and receive data on different tasks. - pub fn split( - self, - ) -> ( - RequestStream, - RequestStream, - ) { - let (send, recv) = self.inner.split(); - ( - RequestStream { - inner: send, - request_end: self.request_end.clone(), - }, - RequestStream { - inner: recv, - request_end: self.request_end, - }, - ) - } -} - -impl Drop for RequestEnd { - fn drop(&mut self) { - if let Err(e) = self.request_end.send(self.stream_id) { - error!( - "failed to notify connection of request end: {} {}", - self.stream_id, e - ); - } - } -} diff --git a/h3/src/server/builder.rs b/h3/src/server/builder.rs new file mode 100644 index 00000000..1836fd1b --- /dev/null +++ b/h3/src/server/builder.rs @@ -0,0 +1,140 @@ +//! Builder of HTTP/3 server connections. +//! +//! Use this struct to create a new [`Connection`]. +//! Settings for the [`Connection`] can be provided here. +//! +//! # Example +//! +//! ```rust +//! fn doc(conn: C) +//! where +//! C: h3::quic::Connection, +//! B: bytes::Buf, +//! { +//! let mut server_builder = h3::server::builder(); +//! // Set the maximum header size +//! server_builder.max_field_section_size(1000); +//! // do not send grease types +//! 
server_builder.send_grease(false); +//! // Build the Connection +//! let mut h3_conn = server_builder.build(conn); +//! } +//! ``` + +use std::{collections::HashSet, result::Result}; + +use bytes::Buf; + +use tokio::sync::mpsc; + +use crate::{ + config::Config, + connection::{ConnectionInner, SharedStateRef}, + error::Error, + quic::{self}, +}; + +use super::connection::Connection; + +/// Create a builder of HTTP/3 server connections +/// +/// This function creates a [`Builder`] that carries settings that can +/// be shared between server connections. +pub fn builder() -> Builder { + Builder::new() +} + +/// Builder of HTTP/3 server connections. +pub struct Builder { + pub(crate) config: Config, +} + +impl Builder { + /// Creates a new [`Builder`] with default settings. + pub(super) fn new() -> Self { + Builder { + config: Default::default(), + } + } + + #[cfg(test)] + pub fn send_settings(&mut self, value: bool) -> &mut Self { + self.config.send_settings = value; + self + } + + /// Set the maximum header size this client is willing to accept + /// + /// See [header size constraints] section of the specification for details. + /// + /// [header size constraints]: https://www.rfc-editor.org/rfc/rfc9114.html#name-header-size-constraints + pub fn max_field_section_size(&mut self, value: u64) -> &mut Self { + self.config.settings.max_field_section_size = value; + self + } + + /// Send grease values to the Client. + /// See [setting](https://www.rfc-editor.org/rfc/rfc9114.html#settings-parameters), [frame](https://www.rfc-editor.org/rfc/rfc9114.html#frame-reserved) and [stream](https://www.rfc-editor.org/rfc/rfc9114.html#stream-grease) for more information. + #[inline] + pub fn send_grease(&mut self, value: bool) -> &mut Self { + self.config.send_grease = value; + self + } + + /// Indicates to the peer that WebTransport is supported. + /// + /// See: [establishing a webtransport session](https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.1) + /// + /// + /// **Server**: + /// Supporting for webtransport also requires setting `enable_connect` `enable_datagram` + /// and `max_webtransport_sessions`. + #[inline] + pub fn enable_webtransport(&mut self, value: bool) -> &mut Self { + self.config.settings.enable_webtransport = value; + self + } + + /// Enables the CONNECT protocol + pub fn enable_connect(&mut self, value: bool) -> &mut Self { + self.config.settings.enable_extended_connect = value; + self + } + + /// Limits the maximum number of WebTransport sessions + pub fn max_webtransport_sessions(&mut self, value: u64) -> &mut Self { + self.config.settings.max_webtransport_sessions = value; + self + } + + /// Indicates that the client or server supports HTTP/3 datagrams + /// + /// See: + pub fn enable_datagram(&mut self, value: bool) -> &mut Self { + self.config.settings.enable_datagram = value; + self + } +} + +impl Builder { + /// Build an HTTP/3 connection from a QUIC connection + /// + /// This method creates a [`Connection`] instance with the settings in the [`Builder`]. 
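Editor's note: the setters above are designed to be chained before `build()`. A minimal sketch of a WebTransport-capable configuration, reusing the generic bounds of the doc example at the top of this file; the function name and the chosen limits are illustrative, not part of this change:

```rust
fn configure<C, B>(conn: C)
where
    C: h3::quic::Connection<B>,
    B: bytes::Buf,
{
    let mut builder = h3::server::builder();
    // WebTransport needs extended CONNECT, datagram support and a session
    // limit in addition to the flag itself (see `enable_webtransport` above).
    builder
        .enable_webtransport(true)
        .enable_connect(true)
        .enable_datagram(true)
        .max_webtransport_sessions(1)
        .max_field_section_size(8 * 1024);
    // `build()` consumes the QUIC connection and returns a future that
    // resolves to the HTTP/3 `Connection`.
    let _connection = builder.build(conn);
}
```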
+ pub async fn build(&self, conn: C) -> Result, Error> + where + C: quic::Connection, + B: Buf, + { + let (sender, receiver) = mpsc::unbounded_channel(); + Ok(Connection { + inner: ConnectionInner::new(conn, SharedStateRef::default(), self.config).await?, + max_field_section_size: self.config.settings.max_field_section_size, + request_end_send: sender, + request_end_recv: receiver, + ongoing_streams: HashSet::new(), + sent_closing: None, + recv_closing: None, + last_accepted_stream: None, + }) + } +} diff --git a/h3/src/server/connection.rs b/h3/src/server/connection.rs new file mode 100644 index 00000000..b28c8d6c --- /dev/null +++ b/h3/src/server/connection.rs @@ -0,0 +1,473 @@ +//! HTTP/3 server connection +//! +//! The [`Connection`] struct manages a connection from the side of the HTTP/3 server + +use std::{ + collections::HashSet, + marker::PhantomData, + option::Option, + result::Result, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::Buf; +use futures_util::{ + future::{self}, + ready, +}; +use http::Request; +use quic::RecvStream; +use quic::StreamId; +use tokio::sync::mpsc; + +use crate::{ + connection::{self, ConnectionInner, ConnectionState, SharedStateRef}, + error::{Code, Error, ErrorLevel}, + ext::Datagram, + frame::{FrameStream, FrameStreamError}, + proto::{ + frame::{Frame, PayloadLen}, + push::PushId, + }, + qpack, + quic::{self, RecvDatagramExt, SendDatagramExt, SendStream as _}, + stream::BufRecvStream, +}; + +use crate::server::request::ResolveRequest; + +use tracing::{trace, warn}; + +use super::stream::{ReadDatagram, RequestStream}; + +/// Server connection driver +/// +/// The [`Connection`] struct manages a connection from the side of the HTTP/3 server +/// +/// Create a new Instance with [`Connection::new()`]. +/// Accept incoming requests with [`Connection::accept()`]. +/// And shutdown a connection with [`Connection::shutdown()`]. +pub struct Connection +where + C: quic::Connection, + B: Buf, +{ + /// TODO: temporarily break encapsulation for `WebTransportSession` + pub inner: ConnectionInner, + pub(super) max_field_section_size: u64, + // List of all incoming streams that are currently running. + pub(super) ongoing_streams: HashSet, + // Let the streams tell us when they are no longer running. + pub(super) request_end_recv: mpsc::UnboundedReceiver, + pub(super) request_end_send: mpsc::UnboundedSender, + // Has a GOAWAY frame been sent? If so, this StreamId is the last we are willing to accept. + pub(super) sent_closing: Option, + // Has a GOAWAY frame been received? If so, this is PushId the last the remote will accept. + pub(super) recv_closing: Option, + // The id of the last stream received by this connection. + pub(super) last_accepted_stream: Option, +} + +impl ConnectionState for Connection +where + C: quic::Connection, + B: Buf, +{ + fn shared_state(&self) -> &SharedStateRef { + &self.inner.shared + } +} + +impl Connection +where + C: quic::Connection, + B: Buf, +{ + /// Create a new HTTP/3 server connection with default settings + /// + /// Use a custom [`super::builder::Builder`] with [`super::builder::builder()`] to create a connection + /// with different settings. + /// Provide a Connection which implements [`quic::Connection`]. + pub async fn new(conn: C) -> Result { + super::builder::builder().build(conn).await + } + + /// Closes the connection with a code and a reason. 
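For orientation, a sketch of how this constructor is usually reached from an `h3-quinn` endpoint. The accept pattern mirrors the test code further down in this diff; the endpoint setup (certificates, `h3` ALPN) is assumed to exist elsewhere:

```rust
async fn accept_one(endpoint: quinn::Endpoint) -> Result<(), Box<dyn std::error::Error>> {
    // Wait for a QUIC connection and wrap it for h3.
    if let Some(incoming) = endpoint.accept().await {
        let quic_conn = h3_quinn::Connection::new(incoming.await?);
        // The payload buffer type has to be pinned down somewhere; `Bytes` is the usual choice.
        let mut h3_conn: h3::server::Connection<_, bytes::Bytes> =
            h3::server::Connection::new(quic_conn).await?;
        // Accept a single request (None means the client closed the connection).
        let _maybe_request = h3_conn.accept().await?;
    }
    Ok(())
}
```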
+ pub fn close>(&mut self, code: Code, reason: T) -> Error { + self.inner.close(code, reason) + } +} + +impl Connection +where + C: quic::Connection, + B: Buf, +{ + /// Accept an incoming request. + /// + /// It returns a tuple with a [`http::Request`] and an [`RequestStream`]. + /// The [`http::Request`] is the received request from the client. + /// The [`RequestStream`] can be used to send the response. + pub async fn accept( + &mut self, + ) -> Result, RequestStream)>, Error> { + // Accept the incoming stream + let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await { + Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)), + Ok(None) => { + // We always send a last GoAway frame to the client, so it knows which was the last + // non-rejected request. + self.shutdown(0).await?; + return Ok(None); + } + Err(err) => { + match err.inner.kind { + crate::error::Kind::Closed => return Ok(None), + crate::error::Kind::Application { + code, + reason, + level: ErrorLevel::ConnectionError, + } => { + return Err(self.inner.close( + code, + reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), + )) + } + _ => return Err(err), + }; + } + }; + + let frame = future::poll_fn(|cx| stream.poll_next(cx)).await; + let req = self.accept_with_frame(stream, frame)?; + if let Some(req) = req { + Ok(Some(req.resolve().await?)) + } else { + Ok(None) + } + } + + /// Accepts an http request where the first frame has already been read and decoded. + /// + /// + /// This is needed as a bidirectional stream may be read as part of incoming webtransport + /// bi-streams. If it turns out that the stream is *not* a `WEBTRANSPORT_STREAM` the request + /// may still want to be handled and passed to the user. + pub fn accept_with_frame( + &mut self, + mut stream: FrameStream, + frame: Result>, FrameStreamError>, + ) -> Result>, Error> { + let mut encoded = match frame { + Ok(Some(Frame::Headers(h))) => h, + + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 + //# If a client-initiated + //# stream terminates without enough of the HTTP message to provide a + //# complete response, the server SHOULD abort its response stream with + //# the error code H3_REQUEST_INCOMPLETE. + Ok(None) => { + return Err(self.inner.close( + Code::H3_REQUEST_INCOMPLETE, + "request stream closed before headers", + )); + } + + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 + //# Receipt of an invalid sequence of frames MUST be treated as a + //# connection error of type H3_FRAME_UNEXPECTED. + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 + //# A server MUST treat the + //# receipt of a PUSH_PROMISE frame as a connection error of type + //# H3_FRAME_UNEXPECTED. + Ok(Some(_)) => { + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 + //# Receipt of an invalid sequence of frames MUST be treated as a + //# connection error of type H3_FRAME_UNEXPECTED. 
+ // Close if the first frame is not a header frame + return Err(self.inner.close( + Code::H3_FRAME_UNEXPECTED, + "first request frame is not headers", + )); + } + Err(e) => { + let err: Error = e.into(); + if err.is_closed() { + return Ok(None); + } + match err.inner.kind { + crate::error::Kind::Closed => return Ok(None), + crate::error::Kind::Application { + code, + reason, + level: ErrorLevel::ConnectionError, + } => { + return Err(self.inner.close( + code, + reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), + )) + } + crate::error::Kind::Application { + code, + reason: _, + level: ErrorLevel::StreamError, + } => { + stream.reset(code.into()); + return Err(err); + } + _ => return Err(err), + }; + } + }; + + let mut request_stream = RequestStream { + request_end: Arc::new(RequestEnd { + request_end: self.request_end_send.clone(), + stream_id: stream.send_id(), + }), + inner: connection::RequestStream::new( + stream, + self.max_field_section_size, + self.inner.shared.clone(), + self.inner.send_grease_frame, + ), + }; + + let decoded = match qpack::decode_stateless(&mut encoded, self.max_field_section_size) { + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 + //# An HTTP/3 implementation MAY impose a limit on the maximum size of + //# the message header it will accept on an individual HTTP message. + Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => Err(cancel_size), + Ok(decoded) => { + // send the grease frame only once + self.inner.send_grease_frame = false; + Ok(decoded) + } + Err(e) => { + let err: Error = e.into(); + if err.is_closed() { + return Ok(None); + } + match err.inner.kind { + crate::error::Kind::Closed => return Ok(None), + crate::error::Kind::Application { + code, + reason, + level: ErrorLevel::ConnectionError, + } => { + return Err(self.inner.close( + code, + reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), + )) + } + crate::error::Kind::Application { + code, + reason: _, + level: ErrorLevel::StreamError, + } => { + request_stream.stop_stream(code); + return Err(err); + } + _ => return Err(err), + }; + } + }; + + Ok(Some(ResolveRequest::new( + request_stream, + decoded, + self.max_field_section_size, + ))) + } + + /// Initiate a graceful shutdown, accepting `max_request` potentially still in-flight + /// + /// See [connection shutdown](https://www.rfc-editor.org/rfc/rfc9114.html#connection-shutdown) for more information. + pub async fn shutdown(&mut self, max_requests: usize) -> Result<(), Error> { + let max_id = self + .last_accepted_stream + .map(|id| id + max_requests) + .unwrap_or(StreamId::FIRST_REQUEST); + + self.inner.shutdown(&mut self.sent_closing, max_id).await + } + + /// Accepts an incoming bidirectional stream. + /// + /// This could be either a *Request* or a *WebTransportBiStream*, the first frame's type + /// decides. + pub fn poll_accept_request( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Error>> { + let _ = self.poll_control(cx)?; + let _ = self.poll_requests_completion(cx); + loop { + match self.inner.poll_accept_request(cx) { + Poll::Ready(Err(x)) => break Poll::Ready(Err(x)), + Poll::Ready(Ok(None)) => { + if self.poll_requests_completion(cx).is_ready() { + break Poll::Ready(Ok(None)); + } else { + // Wait for all the requests to be finished, request_end_recv will wake + // us on each request completion. + break Poll::Pending; + } + } + Poll::Pending => { + if self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready() { + // The connection is now idle. 
+ break Poll::Ready(Ok(None)); + } else { + return Poll::Pending; + } + } + Poll::Ready(Ok(Some(mut s))) => { + // When the connection is in a graceful shutdown procedure, reject all + // incoming requests not belonging to the grace interval. It's possible that + // some acceptable request streams arrive after rejected requests. + if let Some(max_id) = self.sent_closing { + if s.send_id() > max_id { + s.stop_sending(Code::H3_REQUEST_REJECTED.value()); + s.reset(Code::H3_REQUEST_REJECTED.value()); + if self.poll_requests_completion(cx).is_ready() { + break Poll::Ready(Ok(None)); + } + continue; + } + } + self.last_accepted_stream = Some(s.send_id()); + self.ongoing_streams.insert(s.send_id()); + break Poll::Ready(Ok(Some(s))); + } + }; + } + } + + pub(crate) fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll> { + while (self.poll_next_control(cx)?).is_ready() {} + Poll::Pending + } + + pub(crate) fn poll_next_control( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Error>> { + let frame = ready!(self.inner.poll_control(cx))?; + + match &frame { + Frame::Settings(w) => trace!("Got settings > {:?}", w), + &Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?, + f @ Frame::MaxPushId(_) | f @ Frame::CancelPush(_) => { + warn!("Control frame ignored {:?}", f); + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.3 + //= type=TODO + //# If a server receives a CANCEL_PUSH frame for a push + //# ID that has not yet been mentioned by a PUSH_PROMISE frame, this MUST + //# be treated as a connection error of type H3_ID_ERROR. + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.7 + //= type=TODO + //# A MAX_PUSH_ID frame cannot reduce the maximum push + //# ID; receipt of a MAX_PUSH_ID frame that contains a smaller value than + //# previously received MUST be treated as a connection error of type + //# H3_ID_ERROR. + } + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 + //# A server MUST treat the + //# receipt of a PUSH_PROMISE frame as a connection error of type + //# H3_FRAME_UNEXPECTED. + frame => { + return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.with_reason( + format!("on server control stream: {:?}", frame), + ErrorLevel::ConnectionError, + ))) + } + } + Poll::Ready(Ok(frame)) + } + + fn poll_requests_completion(&mut self, cx: &mut Context<'_>) -> Poll<()> { + loop { + match self.request_end_recv.poll_recv(cx) { + // The channel is closed + Poll::Ready(None) => return Poll::Ready(()), + // A request has completed + Poll::Ready(Some(id)) => { + self.ongoing_streams.remove(&id); + } + Poll::Pending => { + if self.ongoing_streams.is_empty() { + // Tell the caller there is not more ongoing requests. + // Still, the completion of future requests will wake us. 
+                        return Poll::Ready(());
+                    } else {
+                        return Poll::Pending;
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl<C, B> Connection<C, B>
+where
+    C: quic::Connection<B> + SendDatagramExt<B>,
+    B: Buf,
+{
+    /// Sends a datagram
+    pub fn send_datagram(&mut self, stream_id: StreamId, data: B) -> Result<(), Error> {
+        self.inner
+            .conn
+            .send_datagram(Datagram::new(stream_id, data))?;
+        tracing::info!("Sent datagram");
+
+        Ok(())
+    }
+}
+
+impl<C, B> Connection<C, B>
+where
+    C: quic::Connection<B> + RecvDatagramExt,
+    B: Buf,
+{
+    /// Reads an incoming datagram
+    pub fn read_datagram(&mut self) -> ReadDatagram<C, B> {
+        ReadDatagram {
+            conn: self,
+            _marker: PhantomData,
+        }
+    }
+}
+
+impl<C, B> Drop for Connection<C, B>
+where
+    C: quic::Connection<B>,
+    B: Buf,
+{
+    fn drop(&mut self) {
+        self.inner.close(Code::H3_NO_ERROR, "");
+    }
+}
+
+//= https://www.rfc-editor.org/rfc/rfc9114#section-6.1
+//= type=TODO
+//# In order to
+//# permit these streams to open, an HTTP/3 server SHOULD configure non-
+//# zero minimum values for the number of permitted streams and the
+//# initial stream flow-control window.
+
+//= https://www.rfc-editor.org/rfc/rfc9114#section-6.1
+//= type=TODO
+//# So as to not unnecessarily limit
+//# parallelism, at least 100 request streams SHOULD be permitted at a
+//# time.
+
+pub(super) struct RequestEnd {
+    pub(super) request_end: mpsc::UnboundedSender<StreamId>,
+    pub(super) stream_id: StreamId,
+}
diff --git a/h3/src/server/mod.rs b/h3/src/server/mod.rs
new file mode 100644
index 00000000..a4663d5b
--- /dev/null
+++ b/h3/src/server/mod.rs
@@ -0,0 +1,62 @@
+//! This module provides methods to create an HTTP/3 server.
+//!
+//! It allows accepting incoming requests and sending responses.
+//!
+//! # Examples
+//!
+//! ## Simple example
+//! ```rust
+//! async fn doc<C>(conn: C)
+//! where
+//!     C: h3::quic::Connection<bytes::Bytes>,
+//!     <C as h3::quic::Connection<bytes::Bytes>>::BidiStream: Send + 'static
+//! {
+//!     let mut server_builder = h3::server::builder();
+//!     // Build the Connection
+//!     let mut h3_conn = server_builder.build(conn).await.unwrap();
+//!     loop {
+//!         // Accept incoming requests
+//!         match h3_conn.accept().await {
+//!             Ok(Some((req, mut stream))) => {
+//!                 // spawn a new task to handle the request
+//!                 tokio::spawn(async move {
+//!                     // build an HTTP response
+//!                     let response = http::Response::builder().status(http::StatusCode::OK).body(()).unwrap();
+//!                     // send the response to the wire
+//!                     stream.send_response(response).await.unwrap();
+//!                     // send some data
+//!                     stream.send_data(bytes::Bytes::from("test")).await.unwrap();
+//!                     // finish the stream
+//!                     stream.finish().await.unwrap();
+//!                 });
+//!             }
+//!             Ok(None) => {
+//!                 // break if no request is accepted
+//!                 break;
+//!             }
+//!             Err(err) => {
+//!                 match err.get_error_level() {
+//!                     // break on connection errors
+//!                     h3::error::ErrorLevel::ConnectionError => break,
+//!                     // continue on stream errors
+//!                     h3::error::ErrorLevel::StreamError => continue,
+//!                 }
+//!             }
+//!         }
+//!     }
+//! }
+//! ```
+//!
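Building on the accept loop above, a sketch of a graceful shutdown: `shutdown(n)` sends a GOAWAY that allows `n` more in-flight requests, and `accept()` then drains whatever remains. The value `1` and the trivial response are illustrative only:

```rust
async fn drain<C>(mut conn: h3::server::Connection<C, bytes::Bytes>) -> Result<(), h3::Error>
where
    C: h3::quic::Connection<bytes::Bytes>,
{
    // Announce via GOAWAY that at most one more in-flight request is accepted.
    conn.shutdown(1).await?;

    // accept() keeps resolving requests inside the grace interval and returns
    // Ok(None) once every outstanding request has completed.
    while let Some((_request, mut stream)) = conn.accept().await? {
        let response = http::Response::builder()
            .status(http::StatusCode::OK)
            .body(())
            .unwrap();
        stream.send_response(response).await?;
        stream.finish().await?;
    }
    Ok(())
}
```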
+//! ## File server
+//! A ready-to-use example of a file server is available [here](https://github.com/hyperium/h3/blob/master/examples/server.rs)
+
+mod builder;
+mod connection;
+mod request;
+mod stream;
+
+pub use builder::builder;
+pub use builder::Builder;
+pub use connection::Connection;
+pub use stream::ReadDatagram;
+pub use stream::RequestStream;
diff --git a/h3/src/server/request.rs b/h3/src/server/request.rs
new file mode 100644
index 00000000..b70e3dfa
--- /dev/null
+++ b/h3/src/server/request.rs
@@ -0,0 +1,94 @@
+use std::convert::TryFrom;
+
+use bytes::Buf;
+use http::{Request, StatusCode};
+
+use crate::{error::Code, proto::headers::Header, qpack, quic, Error};
+
+use super::stream::RequestStream;
+
+pub struct ResolveRequest<C: quic::Connection<B>, B: Buf> {
+    request_stream: RequestStream<C::BidiStream, B>,
+    // Ok or `REQUEST_HEADER_FIELDS_TOO_LARGE` which needs to be sent
+    decoded: Result<qpack::Decoded, u64>,
+    max_field_section_size: u64,
+}
+
+impl<B: Buf, C: quic::Connection<B>> ResolveRequest<C, B> {
+    pub fn new(
+        request_stream: RequestStream<C::BidiStream, B>,
+        decoded: Result<qpack::Decoded, u64>,
+        max_field_section_size: u64,
+    ) -> Self {
+        Self {
+            request_stream,
+            decoded,
+            max_field_section_size,
+        }
+    }
+
+    /// Finishes the resolution of the request
+    pub async fn resolve(
+        mut self,
+    ) -> Result<(Request<()>, RequestStream<C::BidiStream, B>), Error> {
+        let fields = match self.decoded {
+            Ok(v) => v.fields,
+            Err(cancel_size) => {
+                // Send and await the error response
+                self.request_stream
+                    .send_response(
+                        http::Response::builder()
+                            .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
+                            .body(())
+                            .expect("header too big response"),
+                    )
+                    .await?;
+
+                return Err(Error::header_too_big(
+                    cancel_size,
+                    self.max_field_section_size,
+                ));
+            }
+        };
+
+        // Parse the request headers
+        let (method, uri, protocol, headers) = match Header::try_from(fields) {
+            Ok(header) => match header.into_request_parts() {
+                Ok(parts) => parts,
+                Err(err) => {
+                    //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2
+                    //# Malformed requests or responses that are
+                    //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR.
+                    let error: Error = err.into();
+                    self.request_stream
+                        .stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR));
+                    return Err(error);
+                }
+            },
+            Err(err) => {
+                //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2
+                //# Malformed requests or responses that are
+                //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR.
+                let error: Error = err.into();
+                self.request_stream
+                    .stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR));
+                return Err(error);
+            }
+        };
+
+        // request_stream.stop_stream(Code::H3_MESSAGE_ERROR).await;
+        let mut req = http::Request::new(());
+        *req.method_mut() = method;
+        *req.uri_mut() = uri;
+        *req.headers_mut() = headers;
+        // NOTE: insert `Protocol` and not `Option<Protocol>`
+        if let Some(protocol) = protocol {
+            req.extensions_mut().insert(protocol);
+        }
+        *req.version_mut() = http::Version::HTTP_3;
+        // send the grease frame only once
+        // self.inner.send_grease_frame = false;
+        tracing::trace!("replying with: {:?}", req);
+        Ok((req, self.request_stream))
+    }
+}
diff --git a/h3/src/server/stream.rs b/h3/src/server/stream.rs
new file mode 100644
index 00000000..610b83b6
--- /dev/null
+++ b/h3/src/server/stream.rs
@@ -0,0 +1,243 @@
+//!
Server-side HTTP/3 stream management + +use bytes::Buf; + +use crate::{ + connection::{ConnectionState, SharedStateRef}, + ext::Datagram, + quic::{self, RecvDatagramExt}, + Error, +}; +use pin_project_lite::pin_project; + +use super::connection::{Connection, RequestEnd}; +use std::{marker::PhantomData, sync::Arc}; + +use std::{ + option::Option, + result::Result, + task::{Context, Poll}, +}; + +use bytes::BytesMut; +use futures_util::{future::Future, ready}; +use http::{response, HeaderMap, Response}; + +use quic::StreamId; + +use crate::{ + error::Code, + proto::{frame::Frame, headers::Header}, + qpack, + quic::SendStream as _, + stream::{self}, +}; + +use tracing::error; + +/// Manage request and response transfer for an incoming request +/// +/// The [`RequestStream`] struct is used to send and/or receive +/// information from the client. +pub struct RequestStream { + pub(super) inner: crate::connection::RequestStream, + pub(super) request_end: Arc, +} + +impl AsMut> for RequestStream { + fn as_mut(&mut self) -> &mut crate::connection::RequestStream { + &mut self.inner + } +} + +impl ConnectionState for RequestStream { + fn shared_state(&self) -> &SharedStateRef { + &self.inner.conn_state + } +} + +impl RequestStream +where + S: quic::RecvStream, + B: Buf, +{ + /// Receive data sent from the client + pub async fn recv_data(&mut self) -> Result, Error> { + self.inner.recv_data().await + } + + /// Poll for data sent from the client + pub fn poll_recv_data( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Error>> { + self.inner.poll_recv_data(cx) + } + + /// Receive an optional set of trailers for the request + pub async fn recv_trailers(&mut self) -> Result, Error> { + self.inner.recv_trailers().await + } + + /// Tell the peer to stop sending into the underlying QUIC stream + pub fn stop_sending(&mut self, error_code: crate::error::Code) { + self.inner.stream.stop_sending(error_code) + } + + /// Returns the underlying stream id + pub fn id(&self) -> StreamId { + self.inner.stream.id() + } +} + +impl RequestStream +where + S: quic::SendStream, + B: Buf, +{ + /// Send the HTTP/3 response + /// + /// This should be called before trying to send any data with + /// [`RequestStream::send_data`]. + pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> { + let (parts, _) = resp.into_parts(); + let response::Parts { + status, headers, .. + } = parts; + let headers = Header::response(status, headers); + + let mut block = BytesMut::new(); + let mem_size = qpack::encode_stateless(&mut block, headers)?; + + let max_mem_size = self + .inner + .conn_state + .read("send_response") + .peer_config + .max_field_section_size; + + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 + //# An implementation that + //# has received this parameter SHOULD NOT send an HTTP message header + //# that exceeds the indicated size, as the peer will likely refuse to + //# process it. + if mem_size > max_mem_size { + return Err(Error::header_too_big(mem_size, max_mem_size)); + } + + stream::write(&mut self.inner.stream, Frame::Headers(block.freeze())) + .await + .map_err(|e| self.maybe_conn_err(e))?; + + Ok(()) + } + + /// Send some data on the response body. + pub async fn send_data(&mut self, buf: B) -> Result<(), Error> { + self.inner.send_data(buf).await + } + + /// Stop a stream with an error code + /// + /// The code can be [`Code::H3_NO_ERROR`]. 
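As a usage note for the receive half defined above, a hedged sketch that drains a request body and then picks up optional trailers; the helper name is illustrative and error handling is kept minimal:

```rust
async fn read_body<S, B>(stream: &mut h3::server::RequestStream<S, B>) -> Result<Vec<u8>, h3::Error>
where
    S: h3::quic::RecvStream,
    B: bytes::Buf,
{
    use bytes::Buf as _;

    let mut body = Vec::new();
    // recv_data() yields chunks until the client finishes its side of the stream.
    while let Some(mut chunk) = stream.recv_data().await? {
        while chunk.has_remaining() {
            let part = chunk.chunk();
            body.extend_from_slice(part);
            let advance = part.len();
            chunk.advance(advance);
        }
    }
    // Trailers, if any, follow the last DATA frame.
    let _trailers = stream.recv_trailers().await?;
    Ok(body)
}
```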
+ pub fn stop_stream(&mut self, error_code: Code) { + self.inner.stop_stream(error_code); + } + + /// Send a set of trailers to end the response. + /// + /// Either [`RequestStream::finish`] or + /// [`RequestStream::send_trailers`] must be called to finalize a + /// request. + pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> { + self.inner.send_trailers(trailers).await + } + + /// End the response without trailers. + /// + /// Either [`RequestStream::finish`] or + /// [`RequestStream::send_trailers`] must be called to finalize a + /// request. + pub async fn finish(&mut self) -> Result<(), Error> { + self.inner.finish().await + } + + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.1 + //= type=TODO + //# Implementations SHOULD cancel requests by abruptly terminating any + //# directions of a stream that are still open. To do so, an + //# implementation resets the sending parts of streams and aborts reading + //# on the receiving parts of streams; see Section 2.4 of + //# [QUIC-TRANSPORT]. + + /// Returns the underlying stream id + pub fn send_id(&self) -> StreamId { + self.inner.stream.send_id() + } +} + +impl RequestStream +where + S: quic::BidiStream, + B: Buf, +{ + /// Splits the Request-Stream into send and receive. + /// This can be used the send and receive data on different tasks. + pub fn split( + self, + ) -> ( + RequestStream, + RequestStream, + ) { + let (send, recv) = self.inner.split(); + ( + RequestStream { + inner: send, + request_end: self.request_end.clone(), + }, + RequestStream { + inner: recv, + request_end: self.request_end, + }, + ) + } +} + +impl Drop for RequestEnd { + fn drop(&mut self) { + if let Err(e) = self.request_end.send(self.stream_id) { + error!( + "failed to notify connection of request end: {} {}", + self.stream_id, e + ); + } + } +} + +pin_project! { + /// Future for [`Connection::read_datagram`] + pub struct ReadDatagram<'a, C, B> + where + C: quic::Connection, + B: Buf, + { + pub(super) conn: &'a mut Connection, + pub(super) _marker: PhantomData, + } +} + +impl<'a, C, B> Future for ReadDatagram<'a, C, B> +where + C: quic::Connection + RecvDatagramExt, + B: Buf, +{ + type Output = Result>, Error>; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + tracing::trace!("poll: read_datagram"); + match ready!(self.conn.inner.conn.poll_accept_datagram(cx))? { + Some(v) => Poll::Ready(Ok(Some(Datagram::decode(v)?))), + None => Poll::Ready(Ok(None)), + } + } +} diff --git a/h3/src/stream.rs b/h3/src/stream.rs index 0314867d..ab91f559 100644 --- a/h3/src/stream.rs +++ b/h3/src/stream.rs @@ -1,24 +1,31 @@ -use std::task::{Context, Poll}; +use std::{ + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; -use bytes::{Buf, BufMut as _, Bytes}; +use bytes::{Buf, BufMut, Bytes}; use futures_util::{future, ready}; -use quic::RecvStream; +use pin_project_lite::pin_project; +use tokio::io::ReadBuf; use crate::{ buf::BufList, error::{Code, ErrorLevel}, frame::FrameStream, proto::{ - coding::{BufExt, Decode as _, Encode}, - frame::Frame, + coding::{Decode as _, Encode}, + frame::{Frame, Settings}, stream::StreamType, varint::VarInt, }, - quic::{self, SendStream}, + quic::{self, BidiStream, RecvStream, SendStream, SendStreamUnframed}, + webtransport::SessionId, Error, }; #[inline] +/// Transmits data by encoding in wire format. 
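Circling back to `RequestStream::split()` defined in the server stream module above: a sketch of driving the two halves from separate tasks. The extra `Send + 'static` bounds are assumptions needed for `tokio::spawn`, not requirements imposed by `split()` itself:

```rust
async fn handle<S, B>(stream: h3::server::RequestStream<S, B>) -> Result<(), h3::Error>
where
    S: h3::quic::BidiStream<B>,
    S::SendStream: Send + 'static,
    S::RecvStream: Send + 'static,
    B: bytes::Buf + Send + 'static,
{
    let (mut tx, mut rx) = stream.split();

    // Drain the request body on one task...
    let reader = tokio::spawn(async move {
        while rx.recv_data().await?.is_some() {}
        Ok::<_, h3::Error>(())
    });

    // ...while the response is produced on another.
    let writer = tokio::spawn(async move {
        let response = http::Response::builder()
            .status(http::StatusCode::OK)
            .body(())
            .unwrap();
        tx.send_response(response).await?;
        tx.finish().await?;
        Ok::<_, h3::Error>(())
    });

    let (r, w) = tokio::join!(reader, writer);
    r.expect("reader task panicked")?;
    w.expect("writer task panicked")?;
    Ok(())
}
```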
pub(crate) async fn write(stream: &mut S, data: D) -> Result<(), Error> where S: SendStream, @@ -43,10 +50,7 @@ const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_E /// data is necessary (say, in `quic::SendStream::send_data`). It also has a public API ergonomy /// advantage: `WriteBuf` doesn't have to appear in public associated types. On the other hand, /// QUIC implementers have to call `into()`, which will encode the header in `Self::buf`. -pub struct WriteBuf -where - B: Buf, -{ +pub struct WriteBuf { buf: [u8; WRITE_BUF_ENCODE_SIZE], len: usize, pos: usize, @@ -59,10 +63,17 @@ where { fn encode_stream_type(&mut self, ty: StreamType) { let mut buf_mut = &mut self.buf[self.len..]; + ty.encode(&mut buf_mut); self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut(); } + fn encode_value(&mut self, value: impl Encode) { + let mut buf_mut = &mut self.buf[self.len..]; + value.encode(&mut buf_mut); + self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut(); + } + fn encode_frame_header(&mut self) { if let Some(frame) = self.frame.as_ref() { let mut buf_mut = &mut self.buf[self.len..]; @@ -88,6 +99,80 @@ where } } +impl From for WriteBuf +where + B: Buf, +{ + fn from(header: UniStreamHeader) -> Self { + let mut this = Self { + buf: [0; WRITE_BUF_ENCODE_SIZE], + len: 0, + pos: 0, + frame: None, + }; + + this.encode_value(header); + this + } +} + +pub enum UniStreamHeader { + Control(Settings), + WebTransportUni(SessionId), +} + +impl Encode for UniStreamHeader { + fn encode(&self, buf: &mut B) { + match self { + Self::Control(settings) => { + StreamType::CONTROL.encode(buf); + settings.encode(buf); + } + Self::WebTransportUni(session_id) => { + StreamType::WEBTRANSPORT_UNI.encode(buf); + session_id.encode(buf); + } + } + } +} + +impl From for WriteBuf +where + B: Buf, +{ + fn from(header: BidiStreamHeader) -> Self { + let mut this = Self { + buf: [0; WRITE_BUF_ENCODE_SIZE], + len: 0, + pos: 0, + frame: None, + }; + + this.encode_value(header); + this + } +} + +pub enum BidiStreamHeader { + Control(Settings), + WebTransportBidi(SessionId), +} + +impl Encode for BidiStreamHeader { + fn encode(&self, buf: &mut B) { + match self { + Self::Control(settings) => { + StreamType::CONTROL.encode(buf); + settings.encode(buf); + } + Self::WebTransportBidi(session_id) => { + StreamType::WEBTRANSPORT_BIDI.encode(buf); + session_id.encode(buf); + } + } + } +} + impl From> for WriteBuf where B: Buf, @@ -116,7 +201,7 @@ where pos: 0, frame: Some(frame), }; - me.encode_stream_type(ty); + me.encode_value(ty); me.encode_frame_header(); me } @@ -162,50 +247,49 @@ where pub(super) enum AcceptedRecvStream where S: quic::RecvStream, + B: Buf, { Control(FrameStream), - Push(u64, FrameStream), - Encoder(S), - Decoder(S), + Push(FrameStream), + Encoder(BufRecvStream), + Decoder(BufRecvStream), + WebTransportUni(SessionId, BufRecvStream), Reserved, } -pub(super) struct AcceptRecvStream -where - S: quic::RecvStream, -{ - stream: S, +/// Resolves an incoming streams type as well as `PUSH_ID`s and `SESSION_ID`s +pub(super) struct AcceptRecvStream { + stream: BufRecvStream, ty: Option, - push_id: Option, - buf: BufList, + /// push_id or session_id + id: Option, expected: Option, } -impl AcceptRecvStream +impl AcceptRecvStream where S: RecvStream, + B: Buf, { pub fn new(stream: S) -> Self { Self { - stream, + stream: BufRecvStream::new(stream), ty: None, - push_id: None, - buf: BufList::new(), + id: None, expected: None, } } - pub fn into_stream(self) -> Result, Error> { + pub fn 
into_stream(self) -> Result, Error> { Ok(match self.ty.expect("Stream type not resolved yet") { - StreamType::CONTROL => { - AcceptedRecvStream::Control(FrameStream::with_bufs(self.stream, self.buf)) - } - StreamType::PUSH => AcceptedRecvStream::Push( - self.push_id.expect("Push ID not resolved yet"), - FrameStream::with_bufs(self.stream, self.buf), - ), + StreamType::CONTROL => AcceptedRecvStream::Control(FrameStream::new(self.stream)), + StreamType::PUSH => AcceptedRecvStream::Push(FrameStream::new(self.stream)), StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream), StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream), + StreamType::WEBTRANSPORT_UNI => AcceptedRecvStream::WebTransportUni( + SessionId::from_varint(self.id.expect("Session ID not resolved yet")), + self.stream, + ), t if t.value() > 0x21 && (t.value() - 0x21) % 0x1f == 0 => AcceptedRecvStream::Reserved, //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2 @@ -233,37 +317,42 @@ where pub fn poll_type(&mut self, cx: &mut Context) -> Poll> { loop { - match (self.ty.as_ref(), self.push_id) { - // When accepting a Push stream, we want to parse two VarInts: [StreamType, PUSH_ID] - (Some(&StreamType::PUSH), Some(_)) | (Some(_), _) => return Poll::Ready(Ok(())), - _ => (), - } - - match ready!(self.stream.poll_data(cx))? { - Some(mut b) => self.buf.push_bytes(&mut b), - None => { - return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason( - "Stream closed before type received", - ErrorLevel::ConnectionError, - ))); + // Return if all identification data is met + match self.ty { + Some(StreamType::PUSH | StreamType::WEBTRANSPORT_UNI) => { + if self.id.is_some() { + return Poll::Ready(Ok(())); + } } + Some(_) => return Poll::Ready(Ok(())), + None => (), + }; + + if ready!(self.stream.poll_read(cx))? { + return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason( + "Stream closed before type received", + ErrorLevel::ConnectionError, + ))); }; - if self.expected.is_none() && self.buf.remaining() >= 1 { - self.expected = Some(VarInt::encoded_size(self.buf.chunk()[0])); + let mut buf = self.stream.buf_mut(); + if self.expected.is_none() && buf.remaining() >= 1 { + self.expected = Some(VarInt::encoded_size(buf.chunk()[0])); } if let Some(expected) = self.expected { - if self.buf.remaining() < expected { + // Poll for more data + if buf.remaining() < expected { continue; } } else { continue; } + // Parse ty and then id if self.ty.is_none() { // Parse StreamType - self.ty = Some(StreamType::decode(&mut self.buf).map_err(|_| { + self.ty = Some(StreamType::decode(&mut buf).map_err(|_| { Code::H3_INTERNAL_ERROR.with_reason( "Unexpected end parsing stream type", ErrorLevel::ConnectionError, @@ -273,9 +362,9 @@ where self.expected = None; } else { // Parse PUSH_ID - self.push_id = Some(self.buf.get_var().map_err(|_| { + self.id = Some(VarInt::decode(&mut buf).map_err(|_| { Code::H3_INTERNAL_ERROR.with_reason( - "Unexpected end parsing stream type", + "Unexpected end parsing push or session id", ErrorLevel::ConnectionError, ) })?); @@ -284,10 +373,336 @@ where } } +pin_project! { + /// A stream which allows partial reading of the data without data loss. + /// + /// This fixes the problem where `poll_data` returns more than the needed amount of bytes, + /// requiring correct implementations to hold on to that extra data and return it later. 
+ /// + /// # Usage + /// + /// Implements `quic::RecvStream` which will first return buffered data, and then read from the + /// stream + pub struct BufRecvStream { + buf: BufList, + // Indicates that the end of the stream has been reached + // + // Data may still be available as buffered + eos: bool, + stream: S, + _marker: PhantomData, + } +} + +impl std::fmt::Debug for BufRecvStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufRecvStream") + .field("buf", &self.buf) + .field("eos", &self.eos) + .field("stream", &"...") + .finish() + } +} + +impl BufRecvStream { + pub fn new(stream: S) -> Self { + Self { + buf: BufList::new(), + eos: false, + stream, + _marker: PhantomData, + } + } +} + +impl BufRecvStream { + /// Reads more data into the buffer, returning the number of bytes read. + /// + /// Returns `true` if the end of the stream is reached. + pub fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { + let data = ready!(self.stream.poll_data(cx))?; + + if let Some(mut data) = data { + self.buf.push_bytes(&mut data); + Poll::Ready(Ok(false)) + } else { + self.eos = true; + Poll::Ready(Ok(true)) + } + } + + /// Returns the currently buffered data, allowing it to be partially read + #[inline] + pub(crate) fn buf_mut(&mut self) -> &mut BufList { + &mut self.buf + } + + /// Returns the next chunk of data from the stream + /// + /// Return `None` when there is no more buffered data; use [`Self::poll_read`]. + pub fn take_chunk(&mut self, limit: usize) -> Option { + self.buf.take_chunk(limit) + } + + /// Returns true if there is remaining buffered data + pub fn has_remaining(&mut self) -> bool { + self.buf.has_remaining() + } + + #[inline] + pub(crate) fn buf(&self) -> &BufList { + &self.buf + } + + pub fn is_eos(&self) -> bool { + self.eos + } +} + +impl RecvStream for BufRecvStream { + type Buf = Bytes; + + type Error = S::Error; + + fn poll_data( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>> { + // There is data buffered, return that immediately + if let Some(chunk) = self.buf.take_first_chunk() { + return Poll::Ready(Ok(Some(chunk))); + } + + if let Some(mut data) = ready!(self.stream.poll_data(cx))? 
{ + Poll::Ready(Ok(Some(data.copy_to_bytes(data.remaining())))) + } else { + self.eos = true; + Poll::Ready(Ok(None)) + } + } + + fn stop_sending(&mut self, error_code: u64) { + self.stream.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.recv_id() + } +} + +impl SendStream for BufRecvStream +where + B: Buf, + S: SendStream, +{ + type Error = S::Error; + + fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.stream.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.stream.send_id() + } + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_ready(cx) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.stream.send_data(data) + } +} + +impl SendStreamUnframed for BufRecvStream +where + B: Buf, + S: SendStreamUnframed, +{ + #[inline] + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) + } +} + +impl BidiStream for BufRecvStream +where + B: Buf, + S: BidiStream, +{ + type SendStream = BufRecvStream; + + type RecvStream = BufRecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + let (send, recv) = self.stream.split(); + ( + BufRecvStream { + // Sending is not buffered + buf: BufList::new(), + eos: self.eos, + stream: send, + _marker: PhantomData, + }, + BufRecvStream { + buf: self.buf, + eos: self.eos, + stream: recv, + _marker: PhantomData, + }, + ) + } +} + +impl futures_util::io::AsyncRead for BufRecvStream +where + B: Buf, + S: RecvStream, + S::Error: Into, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let p = &mut *self; + // Poll for data if the buffer is empty + // + // If there is data available *do not* poll for more data, as that may suspend indefinitely + // if no more data is sent, causing data loss. + if !p.has_remaining() { + let eos = ready!(p.poll_read(cx).map_err(Into::into))?; + if eos { + return Poll::Ready(Ok(0)); + } + } + + let chunk = p.buf_mut().take_chunk(buf.len()); + if let Some(chunk) = chunk { + assert!(chunk.len() <= buf.len()); + let len = chunk.len().min(buf.len()); + // Write the subset into the destination + buf[..len].copy_from_slice(&chunk); + Poll::Ready(Ok(len)) + } else { + Poll::Ready(Ok(0)) + } + } +} + +impl tokio::io::AsyncRead for BufRecvStream +where + B: Buf, + S: RecvStream, + S::Error: Into, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let p = &mut *self; + // Poll for data if the buffer is empty + // + // If there is data available *do not* poll for more data, as that may suspend indefinitely + // if no more data is sent, causing data loss. 
+ if !p.has_remaining() { + let eos = ready!(p.poll_read(cx).map_err(Into::into))?; + if eos { + return Poll::Ready(Ok(())); + } + } + + let chunk = p.buf_mut().take_chunk(buf.remaining()); + if let Some(chunk) = chunk { + assert!(chunk.len() <= buf.remaining()); + // Write the subset into the destination + buf.put_slice(&chunk); + Poll::Ready(Ok(())) + } else { + Poll::Ready(Ok(())) + } + } +} + +impl futures_util::io::AsyncWrite for BufRecvStream +where + B: Buf, + S: SendStreamUnframed, + S::Error: Into, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { + let p = &mut *self; + p.poll_send(cx, &mut buf).map_err(Into::into) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let p = &mut *self; + p.poll_finish(cx).map_err(Into::into) + } +} + +impl tokio::io::AsyncWrite for BufRecvStream +where + B: Buf, + S: SendStreamUnframed, + S::Error: Into, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { + let p = &mut *self; + p.poll_send(cx, &mut buf).map_err(Into::into) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let p = &mut *self; + p.poll_finish(cx).map_err(Into::into) + } +} + #[cfg(test)] mod tests { + use quinn_proto::coding::BufExt; + use super::*; + #[test] + fn write_wt_uni_header() { + let mut w = WriteBuf::::from(UniStreamHeader::WebTransportUni( + SessionId::from_varint(VarInt(5)), + )); + + let ty = w.get_var().unwrap(); + println!("Got type: {ty} {ty:#x}"); + assert_eq!(ty, 0x54); + + let id = w.get_var().unwrap(); + println!("Got id: {id}"); + } + #[test] fn write_buf_encode_streamtype() { let wbuf = WriteBuf::::from(StreamType::ENCODER); diff --git a/h3/src/tests/connection.rs b/h3/src/tests/connection.rs index 76056965..7d65b604 100644 --- a/h3/src/tests/connection.rs +++ b/h3/src/tests/connection.rs @@ -5,11 +5,13 @@ use std::{borrow::BorrowMut, time::Duration}; use assert_matches::assert_matches; use bytes::{Buf, Bytes, BytesMut}; -use futures_util::{future, StreamExt}; +use futures_util::future; use http::{Request, Response, StatusCode}; +use tokio::sync::oneshot::{self}; +use crate::client::SendRequest; +use crate::{client, server}; use crate::{ - client::{self, SendRequest}, connection::ConnectionState, error::{Code, Error, Kind}, proto::{ @@ -20,7 +22,6 @@ use crate::{ varint::VarInt, }, quic::{self, SendStream}, - server, }; use super::h3_quinn; @@ -32,15 +33,16 @@ async fn connect() { let mut server = pair.server(); let client_fut = async { - let _ = client::new(pair.client().await).await.expect("client init"); + let (mut drive, _client) = client::new(pair.client().await).await.expect("client init"); + future::poll_fn(|cx| drive.poll_close(cx)).await.unwrap(); }; let server_fut = async { let conn = server.next().await; - let _ = server::Connection::new(conn).await.unwrap(); + let _server = server::Connection::new(conn).await.unwrap(); }; - tokio::join!(server_fut, client_fut); + tokio::select!(() = server_fut => (), () = client_fut => panic!("client resolved first")); } #[tokio::test] @@ -48,14 +50,21 @@ async fn accept_request_end_on_client_close() { let mut pair = Pair::default(); let mut server = pair.server(); + let (tx, rx) = oneshot::channel::<()>(); + let client_fut = async { - let _ = 
client::new(pair.client().await).await.expect("client init"); + let client = pair.client().await; + let client = client::new(client).await.expect("client init"); + // wait for the server to accept the connection + rx.await.unwrap(); // client is dropped, it will send H3_NO_ERROR + drop(client); }; let server_fut = async { let conn = server.next().await; let mut incoming = server::Connection::new(conn).await.unwrap(); + tx.send(()).unwrap(); // Accept returns Ok(None) assert!(incoming.accept().await.unwrap().is_none()); }; @@ -65,6 +74,7 @@ async fn accept_request_end_on_client_close() { #[tokio::test] async fn server_drop_close() { + init_tracing(); let mut pair = Pair::default(); let mut server = pair.server(); @@ -73,8 +83,8 @@ async fn server_drop_close() { let _ = server::Connection::new(conn).await.unwrap(); }; - let (mut conn, mut send) = client::new(pair.client().await).await.expect("client init"); let client_fut = async { + let (mut conn, mut send) = client::new(pair.client().await).await.expect("client init"); let request_fut = async move { let mut request_stream = send .send_request(Request::get("http://no.way").body(()).unwrap()) @@ -93,8 +103,45 @@ async fn server_drop_close() { tokio::join!(server_fut, client_fut); } +// In this test the client calls send_data() without doing a finish(), +// i.e client keeps the body stream open. And cient expects server to +// read_data() and send a response +#[tokio::test] +async fn server_send_data_without_finish() { + let mut pair = Pair::default(); + let mut server = pair.server(); + + let client_fut = async { + let (_driver, mut send_request) = client::new(pair.client().await).await.unwrap(); + + let mut req = send_request + .send_request(Request::get("http://no.way").body(()).unwrap()) + .await + .unwrap(); + let data = vec![0; 100]; + req.send_data(bytes::Bytes::copy_from_slice(&data)) + .await + .unwrap(); + let _ = req.recv_response().await.unwrap(); + }; + + let server_fut = async { + let conn = server.next().await; + let mut incoming = server::Connection::new(conn).await.unwrap(); + let (_, mut stream) = incoming.accept().await.unwrap().unwrap(); + let mut data = stream.recv_data().await.unwrap().unwrap(); + let data = data.copy_to_bytes(data.remaining()); + assert_eq!(data.len(), 100); + response(stream).await; + server.endpoint.wait_idle().await; + }; + + tokio::join!(server_fut, client_fut); +} + #[tokio::test] async fn client_close_only_on_last_sender_drop() { + init_tracing(); let mut pair = Pair::default(); let mut server = pair.server(); @@ -109,18 +156,24 @@ async fn client_close_only_on_last_sender_drop() { let client_fut = async { let (mut conn, mut send1) = client::new(pair.client().await).await.expect("client init"); let mut send2 = send1.clone(); - let _ = send1 + let mut request_stream_1 = send1 .send_request(Request::get("http://no.way").body(()).unwrap()) .await - .unwrap() - .finish() - .await; - let _ = send2 + .unwrap(); + + let _ = request_stream_1.recv_response().await; + + let _ = request_stream_1.finish().await; + + let mut request_stream_2 = send2 .send_request(Request::get("http://no.way").body(()).unwrap()) .await - .unwrap() - .finish() - .await; + .unwrap(); + + let _ = request_stream_2.recv_response().await; + + let _ = request_stream_2.finish().await; + drop(send1); drop(send2); @@ -144,7 +197,8 @@ async fn settings_exchange_client() { if client .shared_state() .read("client") - .peer_max_field_section_size + .peer_config + .max_field_section_size == 12 { return; @@ -202,7 +256,12 @@ async fn 
settings_exchange_server() { let settings_change = async { for _ in 0..10 { - if state.read("setting_change").peer_max_field_section_size == 12 { + if state + .read("setting_change") + .peer_config + .max_field_section_size + == 12 + { return; } tokio::time::sleep(Duration::from_millis(2)).await; @@ -218,7 +277,7 @@ async fn settings_exchange_server() { #[tokio::test] async fn client_error_on_bidi_recv() { let mut pair = Pair::default(); - let mut server = pair.server(); + let server = pair.server(); macro_rules! check_err { ($e:expr) => { @@ -247,8 +306,7 @@ async fn client_error_on_bidi_recv() { }; let server_fut = async { - let quinn::NewConnection { connection, .. } = - server.incoming.next().await.unwrap().await.unwrap(); + let connection = server.endpoint.accept().await.unwrap().await.unwrap(); let (mut send, _recv) = connection.open_bi().await.unwrap(); for _ in 0..100 { match send.write(b"I'm not really a server").await { @@ -274,7 +332,7 @@ async fn two_control_streams() { let mut server = pair.server(); let client_fut = async { - let new_connection = pair.client_inner().await; + let connection = pair.client_inner().await; //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 //= type=test @@ -282,7 +340,7 @@ async fn two_control_streams() { //# receipt of a second stream claiming to be a control stream MUST be //# treated as a connection error of type H3_STREAM_CREATION_ERROR. for _ in 0..=1 { - let mut control_stream = new_connection.connection.open_uni().await.unwrap(); + let mut control_stream = connection.open_uni().await.unwrap(); let mut buf = BytesMut::new(); StreamType::CONTROL.encode(&mut buf); control_stream.write_all(&buf[..]).await.unwrap(); @@ -313,8 +371,8 @@ async fn control_close_send_error() { let mut server = pair.server(); let client_fut = async { - let new_connection = pair.client_inner().await; - let mut control_stream = new_connection.connection.open_uni().await.unwrap(); + let connection = pair.client_inner().await; + let mut control_stream = connection.open_uni().await.unwrap(); let mut buf = BytesMut::new(); StreamType::CONTROL.encode(&mut buf); @@ -325,19 +383,33 @@ async fn control_close_send_error() { //# If either control //# stream is closed at any point, this MUST be treated as a connection //# error of type H3_CLOSED_CRITICAL_STREAM. - control_stream.finish().await.unwrap(); // close the client control stream immediately + control_stream.finish().unwrap(); // close the client control stream immediately - let (mut driver, _send) = client::new(h3_quinn::Connection::new(new_connection)) - .await - .unwrap(); + // create the Connection manually so it does not open a second Control stream + + let connection_error = loop { + let accepted = connection.accept_bi().await; + match accepted { + // do nothing with the stream + Ok(_) => continue, + Err(err) => break err, + } + }; - future::poll_fn(|cx| driver.poll_close(cx)).await + let err_code = match connection_error { + quinn::ConnectionError::ApplicationClosed(quinn::ApplicationClose { + error_code, + .. 
+ }) => error_code.into_inner(), + e => panic!("unexpected error: {:?}", e), + }; + assert_eq!(err_code, Code::H3_CLOSED_CRITICAL_STREAM.value()); }; let server_fut = async { let conn = server.next().await; let mut incoming = server::Connection::new(conn).await.unwrap(); - // Driver detects that the recieving side of the control stream has been closed + // Driver detects that the receiving side of the control stream has been closed assert_matches!( incoming.accept().await.map(|_| ()).unwrap_err().kind(), Kind::Application { reason: Some(reason), code: Code::H3_CLOSED_CRITICAL_STREAM, .. } @@ -349,7 +421,7 @@ async fn control_close_send_error() { if *reason == *"control stream closed"); }; - tokio::select! { _ = server_fut => (), _ = client_fut => panic!("client resolved first") }; + tokio::join!(server_fut, client_fut); } #[tokio::test] @@ -359,8 +431,8 @@ async fn missing_settings() { let mut server = pair.server(); let client_fut = async { - let new_connection = pair.client_inner().await; - let mut control_stream = new_connection.connection.open_uni().await.unwrap(); + let connection = pair.client_inner().await; + let mut control_stream = connection.open_uni().await.unwrap(); let mut buf = BytesMut::new(); StreamType::CONTROL.encode(&mut buf); @@ -398,8 +470,8 @@ async fn control_stream_frame_unexpected() { let mut server = pair.server(); let client_fut = async { - let new_connection = pair.client_inner().await; - let mut control_stream = new_connection.connection.open_uni().await.unwrap(); + let connection = pair.client_inner().await; + let mut control_stream = connection.open_uni().await.unwrap(); //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.1 //= type=test @@ -458,18 +530,18 @@ async fn timeout_on_control_frame_read() { async fn goaway_from_server_not_request_id() { init_tracing(); let mut pair = Pair::default(); - let (_, mut server) = pair.server_inner(); + let server = pair.server_inner(); let client_fut = async { - let new_connection = pair.client_inner().await; - let mut control_stream = new_connection.connection.open_uni().await.unwrap(); + let connection = pair.client_inner().await; + let mut control_stream = connection.open_uni().await.unwrap(); let mut buf = BytesMut::new(); StreamType::CONTROL.encode(&mut buf); control_stream.write_all(&buf[..]).await.unwrap(); - control_stream.finish().await.unwrap(); // close the client control stream immediately + control_stream.finish().unwrap(); // close the client control stream immediately - let (mut driver, _send) = client::new(h3_quinn::Connection::new(new_connection)) + let (mut driver, _send) = client::new(h3_quinn::Connection::new(connection)) .await .unwrap(); @@ -487,8 +559,8 @@ async fn goaway_from_server_not_request_id() { }; let server_fut = async { - let conn = server.next().await.unwrap().await.unwrap(); - let mut control_stream = conn.connection.open_uni().await.unwrap(); + let conn = server.accept().await.unwrap().await.unwrap(); + let mut control_stream = conn.open_uni().await.unwrap(); let mut buf = BytesMut::new(); StreamType::CONTROL.encode(&mut buf); diff --git a/h3/src/tests/mod.rs b/h3/src/tests/mod.rs index b3a02ce5..d34d234c 100644 --- a/h3/src/tests/mod.rs +++ b/h3/src/tests/mod.rs @@ -19,14 +19,11 @@ use std::{ }; use bytes::Bytes; -use futures_util::StreamExt; -use rustls::{Certificate, PrivateKey}; +use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use crate::quic; -use h3_quinn::{ - quinn::{Incoming, NewConnection, 
TransportConfig}, - Connection, -}; +use h3_quinn::{quinn::TransportConfig, Connection}; pub fn init_tracing() { let _ = tracing_subscriber::fmt() @@ -36,11 +33,10 @@ pub fn init_tracing() { .try_init(); } -#[derive(Clone)] pub struct Pair { port: u16, - cert: Certificate, - key: PrivateKey, + cert: CertificateDer<'static>, + key: PrivateKeyDer<'static>, config: Arc, } @@ -66,34 +62,36 @@ impl Pair { .initial_rtt(Duration::from_millis(10)); } - pub fn server_inner(&mut self) -> (h3_quinn::Endpoint, Incoming) { - let mut crypto = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .with_no_client_auth() - .with_single_cert(vec![self.cert.clone()], self.key.clone()) - .unwrap(); + pub fn server_inner(&mut self) -> h3_quinn::Endpoint { + let mut crypto = rustls::ServerConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(vec![self.cert.clone()], self.key.clone_key()) + .unwrap(); crypto.max_early_data_size = u32::MAX; crypto.alpn_protocols = vec![b"h3".to_vec()]; - let mut server_config = h3_quinn::quinn::ServerConfig::with_crypto(crypto.into()); + let mut server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(crypto).unwrap(), + )); server_config.transport = self.config.clone(); - let (endpoint, incoming) = + let endpoint = h3_quinn::quinn::Endpoint::server(server_config, "[::]:0".parse().unwrap()).unwrap(); self.port = endpoint.local_addr().unwrap().port(); - (endpoint, incoming) + endpoint } pub fn server(&mut self) -> Server { - let (endpoint, incoming) = self.server_inner(); - Server { endpoint, incoming } + let endpoint = self.server_inner(); + Server { endpoint } } - pub async fn client_inner(&self) -> NewConnection { + pub async fn client_inner(&self) -> quinn::Connection { let addr = (Ipv6Addr::LOCALHOST, self.port) .to_socket_addrs() .unwrap() @@ -101,18 +99,20 @@ impl Pair { .unwrap(); let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add(&self.cert).unwrap(); - let mut crypto = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); + root_cert_store.add(self.cert.clone()).unwrap(); + let mut crypto = rustls::ClientConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); crypto.enable_early_data = true; crypto.alpn_protocols = vec![b"h3".to_vec()]; - let client_config = h3_quinn::quinn::ClientConfig::new(Arc::new(crypto)); + let client_config = h3_quinn::quinn::ClientConfig::new(Arc::new( + QuicClientConfig::try_from(crypto).unwrap(), + )); let mut client_endpoint = h3_quinn::quinn::Endpoint::client("[::]:0".parse().unwrap()).unwrap(); @@ -131,18 +131,18 @@ impl Pair { pub struct Server { pub endpoint: h3_quinn::Endpoint, - pub incoming: Incoming, } impl Server { pub async fn next(&mut self) -> impl quic::Connection { - Connection::new(self.incoming.next().await.unwrap().await.unwrap()) + Connection::new(self.endpoint.accept().await.unwrap().await.unwrap()) } } -pub fn build_certs() -> (Certificate, PrivateKey) { 
+pub fn build_certs() -> (CertificateDer<'static>, PrivateKeyDer<'static>) {
     let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
-    let key = PrivateKey(cert.serialize_private_key_der());
-    let cert = Certificate(cert.serialize_der().unwrap());
-    (cert, key)
+    (
+        cert.cert.into(),
+        PrivateKeyDer::Pkcs8(cert.key_pair.serialize_der().into()),
+    )
 }
diff --git a/h3/src/tests/request.rs b/h3/src/tests/request.rs
index 17a174b5..2001b0b3 100644
--- a/h3/src/tests/request.rs
+++ b/h3/src/tests/request.rs
@@ -69,6 +69,8 @@ async fn get() {
             .await
             .expect("send_data");
         request_stream.finish().await.expect("finish");
+
+        let _ = incoming_req.accept().await.unwrap();
     };
 
     tokio::join!(server_fut, client_fut);
@@ -131,6 +133,8 @@ async fn get_with_trailers_unknown_content_type() {
             .await
             .expect("send_trailers");
         request_stream.finish().await.expect("finish");
+
+        let _ = incoming_req.accept().await.unwrap();
     };
 
     tokio::join!(server_fut, client_fut);
@@ -193,6 +197,8 @@ async fn get_with_trailers_known_content_type() {
             .await
             .expect("send_trailers");
         request_stream.finish().await.expect("finish");
+
+        let _ = incoming_req.accept().await.unwrap();
     };
 
     tokio::join!(server_fut, client_fut);
@@ -246,6 +252,9 @@ async fn post() {
             .expect("server recv body");
         assert_eq!(request_body.chunk(), b"wonderful json");
         request_stream.finish().await.expect("client finish");
+
+        // keep connection until client is finished
+        let _ = incoming_req.accept().await.expect("accept");
     };
 
     tokio::join!(server_fut, client_fut);
@@ -258,7 +267,6 @@ async fn header_too_big_response_from_server() {
     let mut server = pair.server();
 
     let client_fut = async {
-        // Do not poll driver so client doesn't know about server's max_field section size setting
         let (mut driver, mut client) = client::new(pair.client().await).await.expect("client init");
         let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
         let req_fut = async {
@@ -310,7 +318,6 @@ async fn header_too_big_response_from_server_trailers() {
     let mut server = pair.server();
 
     let client_fut = async {
-        // Do not poll driver so client doesn't know about server's max_field_section_size setting
         let (mut driver, mut client) = client::new(pair.client().await).await.expect("client init");
         let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
         let req_fut = async {
@@ -330,6 +337,7 @@ async fn header_too_big_response_from_server_trailers() {
                 .await
                 .expect("send trailers");
             request_stream.finish().await.expect("client finish");
+            let _ = request_stream.recv_response().await;
         };
         tokio::select! {biased; _ = req_fut => (), _ = drive_fut => () }
     };
@@ -375,13 +383,25 @@ async fn header_too_big_client_error() {
     let client_fut = async {
         let (mut driver, mut client) = client::new(pair.client().await).await.expect("client init");
-        let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
+        let drive_fut = async {
+            let err = future::poll_fn(|cx| driver.poll_close(cx))
+                .await
+                .unwrap_err();
+            match err.kind() {
+                // The client never sends data on the request stream
+                Kind::Application { code, .. } => {
+                    assert_eq!(code, Code::H3_REQUEST_INCOMPLETE)
+                }
+                _ => panic!("unexpected error: {:?}", err),
+            }
+        };
         let req_fut = async {
             // pretend client already received server's settings
             client
                 .shared_state()
                 .write("client")
-                .peer_max_field_section_size = 12;
+                .peer_config
+                .max_field_section_size = 12;
 
             let req = Request::get("http://localhost/salut").body(()).unwrap();
             let err_kind = client
@@ -399,20 +419,19 @@ async fn header_too_big_client_error() {
                 }
             );
         };
-        tokio::select! {biased; _ = req_fut => (),_ = drive_fut => () }
+        tokio::join! {req_fut, drive_fut }
     };
 
     let server_fut = async {
         let conn = server.next().await;
-        //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2
-        //= type=test
-        //# An HTTP/3 implementation MAY impose a limit on the maximum size of
-        //# the message header it will accept on an individual HTTP message.
-        server::builder()
+
+        let mut incoming_req = server::builder()
            .max_field_section_size(12)
            .build(conn)
            .await
            .unwrap();
+
+        let _ = incoming_req.accept().await;
     };
 
     tokio::join!(server_fut, client_fut);
@@ -425,14 +444,22 @@ async fn header_too_big_client_error_trailer() {
     let mut server = pair.server();
 
     let client_fut = async {
-        // Do not poll driver so client doesn't know about server's max_field_section_size setting
         let (mut driver, mut client) = client::new(pair.client().await).await.expect("client init");
-        let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
+        let drive_fut = async {
+            let err = future::poll_fn(|cx| driver.poll_close(cx))
+                .await
+                .unwrap_err();
+            match err.kind() {
+                Kind::Timeout => (),
+                _ => panic!("unexpected error: {:?}", err),
+            }
+        };
         let req_fut = async {
             client
                 .shared_state()
                 .write("client")
-                .peer_max_field_section_size = 200;
+                .peer_config
+                .max_field_section_size = 200;
             let mut request_stream = client
                 .send_request(Request::get("http://localhost/salut").body(()).unwrap())
                 .await
@@ -462,7 +489,7 @@ async fn header_too_big_client_error_trailer() {
             request_stream.finish().await.expect("client finish");
         };
 
-        tokio::select! {biased; _ = req_fut => (), _ = drive_fut => () }
+        tokio::join! {req_fut,drive_fut};
     };
 
     let server_fut = async {
@@ -503,33 +530,38 @@ async fn header_too_big_discard_from_client() {
         //# that exceeds the indicated size, as the peer will likely refuse to
         //# process it.
 
-        // Do not poll driver so client doesn't know about server's max_field section size setting
-        let (_conn, mut client) = client::builder()
+        let (mut driver, mut client) = client::builder()
             .max_field_section_size(12)
+            // Don't send settings, so server doesn't know about the low max_field_section_size
+            .send_settings(false)
             .build::<_, _, Bytes>(pair.client().await)
             .await
             .expect("client init");
-        let mut request_stream = client
-            .send_request(Request::get("http://localhost/salut").body(()).unwrap())
-            .await
-            .expect("request");
-        request_stream.finish().await.expect("client finish");
-        let err_kind = request_stream.recv_response().await.unwrap_err().kind();
-        assert_matches!(
-            err_kind,
-            Kind::HeaderTooBig {
-                actual_size: 42,
-                max_size: 12,
-                ..
-            }
-        );
+        let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
+        let req_fut = async {
+            let mut request_stream = client
+                .send_request(Request::get("http://localhost/salut").body(()).unwrap())
+                .await
+                .expect("request");
+            request_stream.finish().await.expect("client finish");
+            let err_kind = request_stream.recv_response().await.unwrap_err().kind();
+            assert_matches!(
+                err_kind,
+                Kind::HeaderTooBig {
+                    actual_size: 42,
+                    max_size: 12,
+                    ..
+                }
+            );
 
-        let mut request_stream = client
-            .send_request(Request::get("http://localhost/salut").body(()).unwrap())
-            .await
-            .expect("request");
-        request_stream.finish().await.expect("client finish");
-        let _ = request_stream.recv_response().await.unwrap_err();
+            let mut request_stream = client
+                .send_request(Request::get("http://localhost/salut").body(()).unwrap())
+                .await
+                .expect("request");
+            request_stream.finish().await.expect("client finish");
+            let _ = request_stream.recv_response().await.unwrap_err();
+        };
+        tokio::select! {biased; _ = req_fut => (), _ = drive_fut => () }
     };
 
     let server_fut = async {
@@ -537,11 +569,6 @@ async fn header_too_big_discard_from_client() {
         let mut incoming_req = server::Connection::new(conn).await.unwrap();
 
         let (_request, mut request_stream) = incoming_req.accept().await.expect("accept").unwrap();
-        // pretend server didn't receive settings
-        incoming_req
-            .shared_state()
-            .write("client")
-            .peer_max_field_section_size = u64::MAX;
         request_stream
             .send_response(
                 Response::builder()
@@ -588,12 +615,14 @@ async fn header_too_big_discard_from_client_trailers() {
         //# that exceeds the indicated size, as the peer will likely refuse to
         //# process it.
 
-        // Do not poll driver so client doesn't know about server's max_field section size setting
         let (mut driver, mut client) = client::builder()
             .max_field_section_size(200)
+            // Don't send settings, so server doesn't know about the low max_field_section_size
+            .send_settings(false)
             .build::<_, _, Bytes>(pair.client().await)
             .await
             .expect("client init");
+
         let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
         let req_fut = async {
             let mut request_stream = client
@@ -623,12 +652,6 @@ async fn header_too_big_discard_from_client_trailers() {
         let (_request, mut request_stream) = incoming_req.accept().await.expect("accept").unwrap();
 
-        // pretend server didn't receive settings
-        incoming_req
-            .shared_state()
-            .write("server")
-            .peer_max_field_section_size = u64::MAX;
-
         request_stream
             .send_response(
                 Response::builder()
@@ -698,7 +721,8 @@ async fn header_too_big_server_error() {
         incoming_req
             .shared_state()
             .write("server")
-            .peer_max_field_section_size = 12;
+            .peer_config
+            .max_field_section_size = 12;
 
         let err_kind = request_stream
             .send_response(
@@ -778,7 +802,8 @@ async fn header_too_big_server_error_trailers() {
         incoming_req
             .shared_state()
             .write("write")
-            .peer_max_field_section_size = 200;
+            .peer_config
+            .max_field_section_size = 200;
 
         let mut trailers = HeaderMap::new();
         trailers.insert("trailer", "value".repeat(100).parse().unwrap());
@@ -1330,9 +1355,10 @@ fn request_encode(buf: &mut B, req: http::Request<()>) {
         method,
         uri,
         headers,
+        extensions,
         ..
     } = parts;
-    let headers = Header::request(method, uri, headers).unwrap();
+    let headers = Header::request(method, uri, headers, extensions).unwrap();
     let mut block = BytesMut::new();
     qpack::encode_stateless(&mut block, headers).unwrap();
     Frame::headers(block).encode_with_payload(buf);
@@ -1402,13 +1428,13 @@ where
     let mut server = pair.server();
 
     let client_fut = async {
-        let new_connection = pair.client_inner().await;
-        let (mut req_send, mut req_recv) = new_connection.connection.open_bi().await.unwrap();
+        let connection = pair.client_inner().await;
+        let (mut req_send, mut req_recv) = connection.open_bi().await.unwrap();
         let mut buf = BytesMut::new();
         request(&mut buf);
 
         req_send.write_all(&buf[..]).await.unwrap();
-        req_send.finish().await.unwrap();
+        req_send.finish().unwrap();
 
         let res = req_recv
             .read(&mut buf)
@@ -1418,7 +1444,7 @@ where
             .map(|_| ());
         check(res);
 
-        let (mut driver, _send) = client::new(h3_quinn::Connection::new(new_connection))
+        let (mut driver, _send) = client::new(h3_quinn::Connection::new(connection))
             .await
             .unwrap();
 
diff --git a/h3/src/webtransport/mod.rs b/h3/src/webtransport/mod.rs
new file mode 100644
index 00000000..74ddc906
--- /dev/null
+++ b/h3/src/webtransport/mod.rs
@@ -0,0 +1,2 @@
+mod session_id;
+pub use session_id::SessionId;
diff --git a/h3/src/webtransport/session_id.rs b/h3/src/webtransport/session_id.rs
new file mode 100644
index 00000000..b6f4424d
--- /dev/null
+++ b/h3/src/webtransport/session_id.rs
@@ -0,0 +1,50 @@
+use std::convert::TryFrom;
+
+use crate::proto::{
+    coding::{Decode, Encode},
+    stream::{InvalidStreamId, StreamId},
+    varint::VarInt,
+};
+
+/// Identifies a WebTransport session
+///
+/// The session id is the same as the stream id of the CONNECT request.
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct SessionId(u64);
+impl SessionId {
+    pub(crate) fn from_varint(id: VarInt) -> SessionId {
+        Self(id.0)
+    }
+
+    pub(crate) fn into_inner(self) -> u64 {
+        self.0
+    }
+}
+
+impl TryFrom<u64> for SessionId {
+    type Error = InvalidStreamId;
+    fn try_from(v: u64) -> Result<Self, Self::Error> {
+        if v > VarInt::MAX.0 {
+            return Err(InvalidStreamId(v));
+        }
+        Ok(Self(v))
+    }
+}
+
+impl Encode for SessionId {
+    fn encode<B: bytes::BufMut>(&self, buf: &mut B) {
+        VarInt::from_u64(self.0).unwrap().encode(buf);
+    }
+}
+
+impl Decode for SessionId {
+    fn decode<B: bytes::Buf>(buf: &mut B) -> crate::proto::coding::Result<Self> {
+        Ok(Self(VarInt::decode(buf)?.into_inner()))
+    }
+}
+
+impl From<StreamId> for SessionId {
+    fn from(value: StreamId) -> Self {
+        Self(value.index())
+    }
+}
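Reviewer note: below is a minimal, hypothetical sketch (not part of the patch) showing how the SessionId conversions introduced above are expected to behave, written as an in-crate unit test. The module and test names are illustrative; it only exercises the TryFrom<u64> bound check and into_inner(), assuming VarInt::MAX.0 is the largest encodable QUIC varint used in the bound check above.

// Hypothetical in-crate test for illustration only; not part of this diff.
#[cfg(test)]
mod tests {
    use std::convert::TryFrom;

    use super::SessionId;
    use crate::proto::varint::VarInt;

    #[test]
    fn session_id_bounds() {
        // A value inside the varint range converts and reads back unchanged.
        match SessionId::try_from(77u64) {
            Ok(id) => assert_eq!(id.into_inner(), 77),
            Err(_) => panic!("77 fits in a varint"),
        }

        // Anything larger than VarInt::MAX is rejected as an invalid stream id.
        assert!(SessionId::try_from(VarInt::MAX.0 + 1).is_err());
    }
}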