From 40bcb1d9808f55f4a320b16ae1dfd997ca5f4f9b Mon Sep 17 00:00:00 2001 From: Connor Slade Date: Sat, 6 May 2023 15:16:49 -0400 Subject: [PATCH] WebSocket Progress --- examples/tmp.rs | 3 ++ lib/http/web_socket.rs | 116 ++++++++++++++++++++++++++++++----------- 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/examples/tmp.rs b/examples/tmp.rs index 89f7d5f..27ad2cb 100644 --- a/examples/tmp.rs +++ b/examples/tmp.rs @@ -111,6 +111,9 @@ fn main() { server.route(Method::GET, "/ws", |req| { let stream = req.ws().unwrap(); + stream.send("ello world"); + thread::park(); + Response::end() }); diff --git a/lib/http/web_socket.rs b/lib/http/web_socket.rs index a2e790a..04ddcc4 100644 --- a/lib/http/web_socket.rs +++ b/lib/http/web_socket.rs @@ -1,9 +1,10 @@ use std::{ convert::TryInto, + fmt::Display, io::{self, Read, Write}, net::TcpStream, sync::{ - mpsc::{self, Receiver, Sender, SyncSender}, + mpsc::{self, Iter, Receiver, SyncSender}, Arc, Mutex, }, thread, @@ -19,9 +20,10 @@ use crate::{ const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +/// A WebSocket stream. pub struct WebSocketStream { - _rx: Arc>, - _tx: Arc>, + rx: Arc>, + tx: Arc>, } #[derive(Debug)] @@ -36,14 +38,14 @@ struct Frame { payload: Vec, } +/// Types of WebSocket frames #[derive(Debug)] -enum TxType { +pub enum TxType { + /// Close the socket Close, -} - -#[derive(Debug)] -enum RxType { + /// Send / Receive a text message Text(String), + /// Send / Receive a binary message Binary(Vec), } @@ -79,11 +81,14 @@ impl WebSocketStream { break; } + trace!(Level::Debug, "WS: Received: {:?}", &buf[..len]); let frame = match Frame::from_slice(&buf[..len]) { Some(f) => f, None => continue, }; + assert_eq!(&buf[..len], &frame.to_bytes()[..]); + if !frame.fin { todo!("Handle fragmented frames"); } @@ -102,9 +107,7 @@ impl WebSocketStream { 0 => {} 1 => {} 2 => {} - 8 => { - this_s2c.send(TxType::Close).unwrap(); - } + 8 => this_s2c.send(TxType::Close).unwrap(), 9 => {} 10 => {} _ => {} @@ -118,24 +121,43 @@ impl WebSocketStream { for i in rx { trace!(Level::Debug, "WS: Sending {:?}", i); match i { - TxType::Close => { - Frame::close().write(socket.clone()).unwrap(); - } + TxType::Close => Frame::close(), + TxType::Text(s) => Frame::text(s), + TxType::Binary(b) => Frame::binary(b), } + .write(socket.clone()) + .unwrap(); trace!(Level::Debug, "WS: Sent :p"); } }); - // todo: everything else :sweat_smile:\ + Ok(Self { rx: c2s, tx: s2c }) + } + + /// Sends 'text' data to the client. + pub fn send(&self, data: impl Display) { + self.tx.send(TxType::Text(data.to_string())).unwrap(); + } + + /// Sends binary data to the client. + pub fn send_binary(&self, data: Vec) { + self.tx.send(TxType::Binary(data)).unwrap(); + } +} - Ok(Self { _rx: c2s, _tx: s2c }) +impl<'a> IntoIterator for &'a WebSocketStream { + type Item = TxType; + type IntoIter = Iter<'a, TxType>; + + fn into_iter(self) -> Iter<'a, TxType> { + self.rx.iter() } } impl Frame { fn from_slice(buf: &[u8]) -> Option { let fin = buf[0] & 0b1000_0000 != 0; - let rsv = buf[0] & 0b0111_0000 >> 4; + let rsv = (buf[0] & 0b0111_0000) >> 4; let mask = buf[1] & 0b1000_0000 != 0; let opcode = buf[0] & 0b0000_1111; @@ -207,29 +229,38 @@ impl Frame { | Payload Data continued ... | +---------------------------------------------------------------+ */ - fn write(&self, socket: Arc>) -> io::Result<()> { + fn to_bytes(&self) -> Vec { let mut buf = Vec::new(); buf.push((self.fin as u8) << 7 | self.rsv << 4 | self.opcode); - if self.payload_len < 126 { - buf.push((self.mask.is_some() as u8) << 7 | self.payload_len as u8); - } else if self.payload_len < 65536 { - buf.push((self.mask.is_some() as u8) << 7 | 126); - buf.extend_from_slice(&self.payload_len.to_be_bytes()); - } else { - buf.push((self.mask.is_some() as u8) << 7 | 127); - buf.extend_from_slice(&self.payload_len.to_be_bytes()); + match self.payload_len { + ..=125 => buf.push((self.mask.is_some() as u8) << 7 | self.payload_len as u8), + 126..=65535 => { + buf.push((self.mask.is_some() as u8) << 7 | 126); + buf.extend_from_slice(&self.payload_len.to_be_bytes()); + } + _ => { + buf.push((self.mask.is_some() as u8) << 7 | 127); + buf.extend_from_slice(&self.payload_len.to_be_bytes()); + } } - if let Some(mask) = self.mask { - buf.extend_from_slice(&mask); + match self.mask { + Some(mask) => { + buf.extend_from_slice(&mask); + buf.extend_from_slice(&xor_mask(&mask, &self.payload)) + } + None => buf.extend_from_slice(&self.payload), } - buf.extend_from_slice(&self.payload); + buf + } + fn write(&self, socket: Arc>) -> io::Result<()> { + let buf = self.to_bytes(); trace!(Level::Debug, "WS: Writing: {:?}", buf); - socket.force_lock().write(&buf)?; + socket.force_lock().write_all(&buf)?; Ok(()) } @@ -244,6 +275,28 @@ impl Frame { } } + fn text(text: String) -> Self { + Self { + fin: true, + rsv: 0, + opcode: 1, + payload_len: text.len() as u64, + mask: None, + payload: text.into_bytes(), + } + } + + fn binary(binary: Vec) -> Self { + Self { + fin: true, + rsv: 0, + opcode: 2, + payload_len: binary.len() as u64, + mask: None, + payload: binary, + } + } + fn rsv1(&self) -> bool { self.rsv & 0b100 != 0 } @@ -269,12 +322,13 @@ impl WebSocketExt for Request { } } -fn decode(mask: &[u8], data: &[u8]) -> Vec { +fn xor_mask(mask: &[u8], data: &[u8]) -> Vec { debug_assert_eq!(mask.len(), 4); let mut decoded = Vec::with_capacity(data.len()); for (i, byte) in data.iter().enumerate() { decoded.push(byte ^ mask[i % 4]); } + decoded }