Skip to content

Commit

Permalink
Removes duplicated close functions
Browse files Browse the repository at this point in the history
  • Loading branch information
SirCipher committed May 23, 2024
1 parent 26270ee commit c2d8001
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 114 deletions.
14 changes: 10 additions & 4 deletions ratchet_core/src/framed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use rand::{Rng, SeedableRng};
use ratchet_ext::{ExtensionDecoder, FrameHeader as ExtFrameHeader, OpCode as ExtOpCode};
use std::convert::TryFrom;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -494,6 +495,15 @@ where
}
}

pub async fn with_writer<'s, F, A>(&'s mut self, f: F) -> A::Output
where
F: FnOnce(&'s mut I, &'s mut FramedWrite) -> A,
A: Future + 's,
{
let FramedIo { io, writer, .. } = self;
f(io, writer).await
}

pub fn is_server(&self) -> bool {
self.flags.contains(CodecFlags::ROLE)
}
Expand Down Expand Up @@ -574,10 +584,6 @@ where
)
.await
}

pub async fn close(&mut self) {
let _r = self.io.shutdown().await;
}
}

pub async fn read_next<I, E>(
Expand Down
116 changes: 42 additions & 74 deletions ratchet_core/src/split/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::sync::Arc;

use bytes::BytesMut;
use log::{error, trace};
use tokio::io::AsyncWriteExt;

use bilock::{bilock, BiLock};
use ratchet_ext::{ExtensionDecoder, ExtensionEncoder, ReunitableExtension, SplittableExtension};
Expand Down Expand Up @@ -202,10 +201,6 @@ where
}
}
}

async fn close(&mut self) {
let _r = self.split_writer.shutdown().await;
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -474,26 +469,34 @@ where
Ok(Message::Pong(payload.freeze()))
}
Item::Close(reason) => {
let code = reason
.as_ref()
.map(|reason| reason.code)
.unwrap_or(CloseCode::Normal);

close(
role.is_server(),
close_state,
&mut *split_writer.lock().await,
role.is_server(),
reason,
None,
code,
)
.await
.await?;
Ok(Message::Close(reason))
}
},
Err(e) => {
error!("WebSocket read failure: {:?}", e);
close(

// We want to close the connection but return the error produced during the session,
// not any during the close sequence.
let _r = close(
role.is_server(),
close_state,
&mut *split_writer.lock().await,
role.is_server(),
None,
Some(e),
CloseCode::Protocol,
)
.await
.await;
Err(e)
}
}
}
Expand Down Expand Up @@ -528,12 +531,11 @@ where
}

async fn close<S>(
close_state: &AtomicU8,
framed: &mut WriteHalf<S>,
is_server: bool,
reason: Option<CloseReason>,
ret: Option<Error>,
) -> Result<Message, Error>
state_ref: &AtomicU8,
framed: &mut WriteHalf<S>,
code: CloseCode,
) -> Result<(), Error>
where
S: WebSocketStream,
{
Expand All @@ -542,63 +544,29 @@ where
writer,
..
} = framed;
let close_state = match state_ref.load(Ordering::SeqCst) {
STATE_OPEN => CloseState::NotClosed,
STATE_CLOSING => CloseState::Closing,
STATE_CLOSED => CloseState::Closing,
s => panic!("Unexpected close state: {}", s),
};

match close_state.load(Ordering::SeqCst) {
STATE_OPEN => {
let mut code = match (&reason, &ret) {
(Some(reason), None) => u16::from(reason.code).to_be_bytes(),
(None, Some(error)) if error.is_protocol() | error.is_encoding() => {
u16::from(CloseCode::Protocol).to_be_bytes()
}
(Some(reason), Some(_)) => u16::from(reason.code).to_be_bytes(),
_ => u16::from(CloseCode::Normal).to_be_bytes(),
crate::ws::close(
is_server,
close_state,
|state| {
let state_u8 = match state {
CloseState::NotClosed => STATE_OPEN,
CloseState::Closing => STATE_CLOSING,
CloseState::Closed => STATE_CLOSED,
};

// we don't want to immediately await the echoed close frame as the peer may elect to
// drain any pending messages **before** echoing the close frame

let _write_result = writer
.write(
split_writer,
is_server,
OpCode::ControlCode(ControlCode::Close),
HeaderFlags::FIN,
&mut code,
|_, _| Ok(()),
)
.await;

if is_server {
// 7.1.1: the TCP stream should be closed first by the server
//
// We aren't interested in any IO errors produced here as the peer *may* have
// already closed the TCP stream.
framed.close().await;
trace!("IO closed");
}

close_state.store(STATE_CLOSED, Ordering::SeqCst);

match ret {
Some(err) => Err(err),
None => Ok(Message::Close(reason)),
}
}
STATE_CLOSING => {
close_state.store(STATE_CLOSED, Ordering::SeqCst);
if is_server {
// 7.1.1: the TCP stream should be closed first by the server
//
// We aren't interested in any IO errors produced here as the peer *may* have
// already closed the TCP stream.
framed.close().await;
}

Err(ret.unwrap_or_else(|| Error::with_cause(ErrorKind::Close, CloseCause::Stopped)))
}
STATE_CLOSED => Err(Error::with_cause(ErrorKind::Close, CloseCause::Error)),
s => panic!("Unexpected close state: {}", s),
}
state_ref.store(state_u8, Ordering::SeqCst)
},
split_writer,
writer,
code,
)
.await
}

/// An error produced by `reunite` if the halves do not match.
Expand Down
105 changes: 69 additions & 36 deletions ratchet_core/src/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use crate::errors::{CloseCause, Error, ErrorKind, ProtocolError};
use crate::ext::NegotiatedExtension;
use crate::framed::{FramedIo, Item};
use crate::framed::{FramedIo, FramedWrite, Item};
use crate::protocol::{
CloseReason, ControlCode, DataCode, HeaderFlags, Message, MessageType, OpCode, PayloadType,
Role,
Expand All @@ -23,6 +23,7 @@ use crate::{CloseCode, WebSocketConfig, WebSocketStream};
use bytes::BytesMut;
use log::{error, trace};
use ratchet_ext::{Extension, ExtensionEncoder, FrameHeader as ExtFrameHeader};
use tokio::io::AsyncWriteExt;

#[cfg(feature = "split")]
use crate::split::{split, Receiver, Sender};
Expand Down Expand Up @@ -210,11 +211,54 @@ where
}
Ok(Message::Pong(payload.freeze()))
}
Item::Close(reason) => close(close_state, framed, reason, None).await,
Item::Close(reason) => {
let is_server = framed.is_server();
let code = reason
.as_ref()
.map(|reason| reason.code)
.unwrap_or(CloseCode::Normal);

framed
.with_writer(|io, writer| async {
close(
is_server,
*close_state,
|state| {
*close_state = state;
},
io,
writer,
code,
)
.await
})
.await?;
Ok(Message::Close(reason))
}
},
Err(e) => {
error!("WebSocket read failure: {:?}", e);
close(close_state, framed, None, Some(e)).await

let is_server = framed.is_server();

// We want to close the connection but return the error produced during the session,
// not any during the close sequence.
let _r = framed
.with_writer(|io, writer| async {
close(
is_server,
*close_state,
|state| {
*close_state = state;
},
io,
writer,
CloseCode::Protocol,
)
.await
})
.await;
Err(e)
}
}
}
Expand Down Expand Up @@ -375,69 +419,58 @@ where
}
}

async fn close<S>(
close_state: &mut CloseState,
framed: &mut FramedIo<S>,
reason: Option<CloseReason>,
ret: Option<Error>,
) -> Result<Message, Error>
pub async fn close<S, F>(
is_server: bool,
close_state: CloseState,
update_state: F,
io: &mut S,
framed: &mut FramedWrite,
code: CloseCode,
) -> Result<(), Error>
where
S: WebSocketStream,
F: FnOnce(CloseState),
{
trace!("Start close fn. Reason {reason:?}, return with {ret:?}");

let server = framed.is_server();
match *close_state {
match close_state {
CloseState::NotClosed => {
let mut code = match (&reason, &ret) {
(Some(reason), None) => u16::from(reason.code).to_be_bytes(),
(None, Some(error)) if error.is_protocol() | error.is_encoding() => {
u16::from(CloseCode::Protocol).to_be_bytes()
}
(Some(reason), Some(_)) => u16::from(reason.code).to_be_bytes(),
_ => u16::from(CloseCode::Normal).to_be_bytes(),
};

// we don't want to immediately await the echoed close frame as the peer may elect to
// drain any pending messages **before** echoing the close frame

trace!("Close send frame");

let _write_result = framed
.write(
io,
is_server,
OpCode::ControlCode(ControlCode::Close),
HeaderFlags::FIN,
&mut code,
&mut u16::from(code).to_be_bytes(),
|_, _| Ok(()),
)
.await;

if server {
if is_server {
// 7.1.1: the TCP stream should be closed first by the server
//
// We aren't interested in any IO errors produced here as the peer *may* have
// already closed the TCP stream.
framed.close().await;
let _r = io.shutdown().await;
}

*close_state = CloseState::Closed;
update_state(CloseState::Closed);

match ret {
Some(err) => Err(err),
None => Ok(Message::Close(reason)),
}
Ok(())
}
CloseState::Closing => {
*close_state = CloseState::Closed;
if server {
update_state(CloseState::Closed);

if is_server {
// 7.1.1: the TCP stream should be closed first by the server
//
// We aren't interested in any IO errors produced here as the peer *may* have
// already closed the TCP stream.
framed.close().await;
let _r = io.shutdown().await;
}

Err(ret.unwrap_or_else(|| Error::with_cause(ErrorKind::Close, CloseCause::Stopped)))
Err(Error::with_cause(ErrorKind::Close, CloseCause::Stopped))
}
CloseState::Closed => Err(Error::with_cause(ErrorKind::Close, CloseCause::Error)),
}
Expand Down

0 comments on commit c2d8001

Please sign in to comment.