From a09a0e81d7ddc5df511a6dd7fcd0737178e1aeba Mon Sep 17 00:00:00 2001 From: janskiba Date: Tue, 22 Oct 2024 07:31:28 +0000 Subject: [PATCH] refactor: seperate read/writing streams into get/post requests --- broker/src/serve_sockets.rs | 32 +++++++++++++++++--------------- proxy/src/serve_sockets.rs | 26 ++++++++++++++------------ 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/broker/src/serve_sockets.rs b/broker/src/serve_sockets.rs index 3b24e398..5287a7f4 100644 --- a/broker/src/serve_sockets.rs +++ b/broker/src/serve_sockets.rs @@ -1,12 +1,12 @@ use std::{sync::Arc, time::Duration}; -use axum::{body::{Body, BodyDataStream}, extract::{Path, State}, http::{header, request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, Router}; +use axum::{body::{Body, BodyDataStream}, extract::{Path, State}, http::{header, request::Parts, Method, StatusCode}, response::{IntoResponse, Response}, routing::get, Router}; use bytes::{BufMut, Bytes, BytesMut}; use dashmap::mapref::entry::Entry; use futures_util::{stream, StreamExt}; use shared::{expire_map::LazyExpireMap, serde_helpers::DerefSerializer, Encrypted, HasWaitId, HowLongToBlock, Msg, MsgEmpty, MsgId, MsgSigned, MsgSocketRequest}; use tokio::{sync::oneshot, time::Instant}; -use tracing::{debug, warn}; +use tracing::{debug, warn, Span}; use crate::task_manager::TaskManager; @@ -49,7 +49,7 @@ impl Default for SocketState { pub(crate) fn router() -> Router { Router::new() .route("/v1/sockets", get(get_socket_requests).post(post_socket_request)) - .route("/v1/sockets/:id", get(connect_socket)) + .route("/v1/sockets/:id", get(connect_socket).post(connect_socket)) .with_state(SocketState::default()) } @@ -85,16 +85,19 @@ async fn post_socket_request( )) } -// TODO: Instrument with task_id +#[tracing::instrument(skip(method, state, parts, body), fields(is_read))] async fn connect_socket( + method: Method, state: State, Path(task_id): Path, mut parts: Parts, body: Body, // This Result is just an Either type. An error value does not mean something went wrong ) -> Result { + let is_read = method == Method::GET; + Span::current().record("is_read", is_read); let mut body_stream = body.into_data_stream(); - let (is_read, jwt, remaining) = read_header(&mut body_stream).await?; + let (jwt, remaining) = read_header(&mut body_stream).await?; let result = shared::crypto_jwt::verify_with_extended_header::(&mut parts, String::from_utf8_lossy(&jwt).as_ref()).await; let msg = match result { Ok(msg) => msg.msg, @@ -124,13 +127,13 @@ async fn connect_socket( rx }, }; - debug!(%task_id, "Read waiting on write"); + debug!("Read waiting on write"); let recv_res = recv.await; _ = state.task_manager.remove(&task_id); match recv_res { Ok(s) => Ok(Body::from_stream(s).into_response()), Err(_) => { - warn!(%task_id, "Socket connection expired"); + warn!("Socket connection expired"); Err(StatusCode::GONE) }, } @@ -158,7 +161,7 @@ async fn connect_socket( std::task::Poll::Ready(None) })).boxed()); let Ok(()) = send_res else { - warn!(%task_id, "Failed to send socket body. Reciever dropped"); + warn!("Failed to send socket body. Reciever dropped"); return Err(StatusCode::GONE); }; debug!("Write half send the stream to the read half"); @@ -168,7 +171,7 @@ async fn connect_socket( } } -async fn read_header(s: &mut BodyDataStream) -> Result<(bool, Bytes, BytesMut), StatusCode> { +async fn read_header(s: &mut BodyDataStream) -> Result<(Bytes, BytesMut), StatusCode> { async fn next(s: &mut BodyDataStream) -> Result, StatusCode> { s.next().await.transpose().map_err(|e| { warn!(%e, "Failed to read init for sockets"); @@ -179,23 +182,22 @@ async fn read_header(s: &mut BodyDataStream) -> Result<(bool, Bytes, BytesMut), #[derive(Debug)] enum ReadState { ReadingHeader, - ReadingMessage { is_read: bool, len: usize }, + ReadingMessage { len: usize }, } let mut state = ReadState::ReadingHeader; while let Some(mut packet) = next(s).await? { loop { match state { - ReadState::ReadingHeader if buf.len() + packet.len() >= 5 => { + ReadState::ReadingHeader if buf.len() + packet.len() >= 4 => { buf.put(packet.split_to(packet.len())); debug_assert!(packet.is_empty()); - let is_read = buf.split_to(1)[0] == 1; let len = u32::from_be_bytes(buf.split_to(4).as_ref().try_into().unwrap()); - state = ReadState::ReadingMessage { is_read, len: len as usize }; + state = ReadState::ReadingMessage { len: len as usize }; continue; }, - ReadState::ReadingMessage { is_read, len } if buf.len() + packet.len() >= len => { + ReadState::ReadingMessage { len } if buf.len() + packet.len() >= len => { buf.put(packet); - return Ok((is_read, buf.split_to(len).freeze(), buf)) + return Ok((buf.split_to(len).freeze(), buf)) }, _ => break, } diff --git a/proxy/src/serve_sockets.rs b/proxy/src/serve_sockets.rs index f778abdb..cca5d807 100644 --- a/proxy/src/serve_sockets.rs +++ b/proxy/src/serve_sockets.rs @@ -3,7 +3,7 @@ use std::{ }; use axum::{ - body::Body, extract::{Path, Request, State}, http::{self, header, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Json, RequestPartsExt, Router + body::Body, extract::{Path, Request, State}, http::{self, header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Json, RequestPartsExt, Router }; use bytes::{Buf, BufMut, Bytes, BytesMut}; use crypto_secretstream::{Header, Key, PullStream, PushStream}; @@ -136,15 +136,13 @@ async fn create_socket_con( ); return (res.status(), "Failed to post MsgSocketRequest to broker").into_response(); } - let req = match prepare_socket_request(sender, task_id, &state).await { + let req = match prepare_socket_request(false, sender, task_id, &state).await { Ok(req) => req, Err(e) => return e, }; let req = req.map(|b| { let n = b.as_bytes().len(); - let mut body = Vec::with_capacity(n + 5); - // This 0 signals write interest - body.push(0); + let mut body = Vec::with_capacity(n + 4); body.extend(u32::to_be_bytes(n as _)); body.extend(b.as_bytes()); Bytes::from(body) @@ -155,7 +153,6 @@ async fn create_socket_con( }; let (mut parts, body) = req.into_parts(); parts.headers.append(header::TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - parts.headers.append(header::CONNECTION, HeaderValue::from_static("keep-alive")); let stream = stream::once(ready(Ok(body))).chain(Encrypter::new(key).encrypt(og_req.into_body().into_data_stream())); let req = Request::from_parts(parts, reqwest::Body::wrap_stream(stream)); match state.client.execute(req.try_into().expect("Conversion to reqwest::Request should always work")).await { @@ -173,15 +170,13 @@ async fn connect_read( Extension(task_secret_map): Extension, Path(task_id): Path, ) -> Response { - let req = match prepare_socket_request(sender, task_id, &state).await { + let req = match prepare_socket_request(true, sender, task_id, &state).await { Ok(value) => value, Err(e) => return e, }; let req = req.map(|b| { let n = b.as_bytes().len(); - let mut body = Vec::with_capacity(n + 4 + 1); - // This 1 signals read interest - body.push(1); + let mut body = Vec::with_capacity(n + 4); body.extend(u32::to_be_bytes(n as _)); body.extend(b.as_bytes()); Bytes::from(body) @@ -200,7 +195,7 @@ async fn connect_read( Response::new(Body::from_stream(Decrypter::new(key).decrypt(res.bytes_stream()))) } -async fn prepare_socket_request(sender: beam_lib::AppId, task_id: MsgId, state: &State) -> Result, http::Response> { +async fn prepare_socket_request(is_read: bool, sender: beam_lib::AppId, task_id: MsgId, state: &State) -> Result, http::Response> { let msg_empty = MsgEmpty { from: AppOrProxyId::App(sender.clone()), }; @@ -208,7 +203,14 @@ async fn prepare_socket_request(sender: beam_lib::AppId, task_id: MsgId, state: warn!("Failed to serialize MsgEmpty"); return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response()); }; - let new_req = Request::get(format!("/v1/sockets/{task_id}")).body(axum::body::Body::from(body)); + let new_req = Request::builder() + .method(if is_read { + Method::GET + } else { + Method::POST + }) + .uri(format!("/v1/sockets/{task_id}")) + .body(axum::body::Body::from(body)); let get_socket_con_req = match new_req { Ok(req) => req, Err(e) => {