Skip to content

Commit

Permalink
Add specific contexts for websocket types, delete BasicWebSocketConte…
Browse files Browse the repository at this point in the history
…xt (#67)

* Add specific contexts for websocket types, delete BasicWebSocketContext

* WebSocketExtensionContext

* Add allowRemoteHalfClosure

* don't propagate error if channel is already closed

* don't propagate error if channel is already closed #2

* swift format
  • Loading branch information
adam-fowler authored Jul 5, 2024
1 parent 948b417 commit e5e169c
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 74 deletions.
5 changes: 4 additions & 1 deletion Sources/HummingbirdWSClient/Client/ClientConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,15 @@ public struct ClientConnection<ClientChannel: ClientConnectionChannel>: Sendable
/// create a BSD sockets based bootstrap
private func createSocketsBootstrap() -> ClientBootstrap {
return ClientBootstrap(group: self.eventLoopGroup)
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
}

#if canImport(Network)
/// create a NIOTransportServices bootstrap using Network.framework
private func createTSBootstrap() -> NIOTSConnectionBootstrap? {
guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup) else {
guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup)
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
else {
return nil
}
if let tlsOptions {
Expand Down
22 changes: 17 additions & 5 deletions Sources/HummingbirdWSClient/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ import NIOWebSocket
/// }
/// ```
public struct WebSocketClient {
/// Basic context implementation of ``WebSocketContext``.
/// Used by non-router web socket handle function
public struct Context: WebSocketContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

package init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}

enum MultiPlatformTLSConfiguration: Sendable {
case niossl(TLSConfiguration)
#if canImport(Network)
Expand All @@ -50,7 +62,7 @@ public struct WebSocketClient {
/// WebSocket URL
let url: URI
/// WebSocket data handler
let handler: WebSocketDataHandler<BasicWebSocketContext>
let handler: WebSocketDataHandler<Context>
/// configuration
let configuration: WebSocketClientConfiguration
/// EventLoopGroup to use
Expand All @@ -75,7 +87,7 @@ public struct WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) {
self.url = .init(url)
self.handler = handler
Expand All @@ -101,7 +113,7 @@ public struct WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) {
self.url = .init(url)
self.handler = handler
Expand Down Expand Up @@ -195,7 +207,7 @@ extension WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) async throws -> WebSocketCloseFrame? {
let ws = self.init(
url: url,
Expand Down Expand Up @@ -225,7 +237,7 @@ extension WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) async throws -> WebSocketCloseFrame? {
let ws = self.init(
url: url,
Expand Down
6 changes: 3 additions & 3 deletions Sources/HummingbirdWSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ struct WebSocketClientChannel: ClientConnectionChannel {

let urlPath: String
let hostHeader: String
let handler: WebSocketDataHandler<BasicWebSocketContext>
let handler: WebSocketDataHandler<WebSocketClient.Context>
let configuration: WebSocketClientConfiguration

init(handler: @escaping WebSocketDataHandler<BasicWebSocketContext>, url: URI, configuration: WebSocketClientConfiguration) throws {
init(handler: @escaping WebSocketDataHandler<WebSocketClient.Context>, url: URI, configuration: WebSocketClientConfiguration) throws {
guard let hostHeader = Self.urlHostHeader(for: url) else { throw WebSocketClientError.invalidURL }
self.hostHeader = hostHeader
self.urlPath = Self.urlPath(for: url)
Expand Down Expand Up @@ -104,7 +104,7 @@ struct WebSocketClientChannel: ClientConnectionChannel {
autoPing: self.configuration.autoPing
),
asyncChannel: webSocketChannel,
context: BasicWebSocketContext(allocator: webSocketChannel.channel.allocator, logger: logger),
context: WebSocketClient.Context(allocator: webSocketChannel.channel.allocator, logger: logger),
handler: self.handler
)
case .notUpgraded:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try self.decompressor.startStream()
}

func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContext) throws -> WebSocketFrame {
func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame {
if self.state == .idle {
if frame.rsv1 {
self.state = .decompressingMessage
Expand Down Expand Up @@ -231,7 +231,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try self.compressor.startStream()
}

func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContext) throws -> WebSocketFrame {
func compress(_ frame: WebSocketFrame, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame {
// if the frame is larger than `minFrameSizeToCompress` bytes, we haven't received a final frame
// or we are in the process of sending a message compress the data
let shouldWeCompress = frame.data.readableBytes >= self.minFrameSizeToCompress || !frame.fin || self.sendState != .idle
Expand Down Expand Up @@ -292,7 +292,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try? await self.compressor.shutdown()
}

func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame {
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame {
return try await self.decompressor.decompress(
frame,
maxSize: self.configuration.maxDecompressedFrameSize,
Expand All @@ -301,7 +301,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
)
}

func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame {
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame {
let isCorrectType = frame.opcode == .text || frame.opcode == .binary || frame.opcode == .continuation
if isCorrectType {
return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context)
Expand Down
15 changes: 1 addition & 14 deletions Sources/HummingbirdWSCore/WebSocketContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,8 @@
import Logging
import NIOCore

/// Context for WebSocket
/// Protocol for WebSocket Data handling functions context parameter
public protocol WebSocketContext: Sendable {
var allocator: ByteBufferAllocator { get }
var logger: Logger { get }
}

/// Basic context implementation of ``WebSocketContext``.
///
/// Used by non-router and client WebSocket connections
public struct BasicWebSocketContext: WebSocketContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

package init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}
17 changes: 15 additions & 2 deletions Sources/HummingbirdWSCore/WebSocketExtension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@

import Foundation
import HTTPTypes
import Logging
import NIOCore
import NIOWebSocket

/// Basic context implementation of ``WebSocketContext``.
public struct WebSocketExtensionContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}

/// Protocol for WebSocket extension
public protocol WebSocketExtension: Sendable {
/// Extension name
var name: String { get }
/// Process frame received from websocket
func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// Process frame about to be sent to websocket
func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// shutdown extension
func shutdown() async
}
Expand Down
38 changes: 23 additions & 15 deletions Sources/HummingbirdWSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ package actor WebSocketHandler {
var outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>
let type: WebSocketType
let configuration: Configuration
let context: BasicWebSocketContext
let logger: Logger
let allocator: ByteBufferAllocator
var pingData: ByteBuffer
var pingTime: ContinuousClock.Instant = .now
var closeState: CloseState
Expand All @@ -92,7 +93,8 @@ package actor WebSocketHandler {
self.outbound = outbound
self.type = type
self.configuration = configuration
self.context = .init(allocator: context.allocator, logger: context.logger)
self.logger = context.logger
self.allocator = context.allocator
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize)
self.closeState = .open
}
Expand Down Expand Up @@ -175,16 +177,19 @@ package actor WebSocketHandler {
} catch {
closeCode = .unexpectedServerError
}
try await self.close(code: closeCode)
if case .closing = self.closeState {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
break
do {
try await self.close(code: closeCode)
if case .closing = self.closeState {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
break
}
}
}
}
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
} onGracefulShutdown: {
Task {
try? await self.close(code: .normalClosure)
Expand All @@ -201,10 +206,13 @@ package actor WebSocketHandler {
var frame = frame
do {
for ext in self.configuration.extensions {
frame = try await ext.processFrameToSend(frame, context: self.context)
frame = try await ext.processFrameToSend(
frame,
context: WebSocketExtensionContext(allocator: self.allocator, logger: self.logger)
)
}
} catch {
self.context.logger.debug("Closing as we failed to generate valid frame data")
self.logger.debug("Closing as we failed to generate valid frame data")
throw WebSocketHandler.InternalError.close(.unexpectedServerError)
}
// Set mask key if client
Expand All @@ -213,7 +221,7 @@ package actor WebSocketHandler {
}
try await self.outbound.write(frame)

self.context.logger.trace("Sent \(frame.traceDescription)")
self.logger.trace("Sent \(frame.traceDescription)")
}

func finish() {
Expand Down Expand Up @@ -273,7 +281,7 @@ package actor WebSocketHandler {
) async throws {
switch self.closeState {
case .open:
var buffer = self.context.allocator.buffer(capacity: 2 + (reason?.utf8.count ?? 0))
var buffer = self.allocator.buffer(capacity: 2 + (reason?.utf8.count ?? 0))
buffer.write(webSocketErrorCode: code)
if let reason {
buffer.writeString(reason)
Expand Down Expand Up @@ -318,7 +326,7 @@ package actor WebSocketHandler {
.protocolError
}

var buffer = self.context.allocator.buffer(capacity: 2)
var buffer = self.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)

try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
Expand Down
16 changes: 11 additions & 5 deletions Sources/HummingbirdWSCore/WebSocketInboundStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
// parse messages coming from inbound
while let frame = try await self.iterator.next() {
do {
self.handler.context.logger.trace("Received \(frame.traceDescription)")
self.handler.logger.trace("Received \(frame.traceDescription)")
switch frame.opcode {
case .connectionClose:
try await self.handler.receivedClose(frame)
Expand All @@ -70,20 +70,26 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
// apply extensions
var frame = frame
for ext in self.handler.configuration.extensions.reversed() {
frame = try await ext.processReceivedFrame(frame, context: self.handler.context)
frame = try await ext.processReceivedFrame(
frame,
context: WebSocketExtensionContext(allocator: self.handler.allocator, logger: self.handler.logger)
)
}
return .init(from: frame)
default:
// if we receive a reserved opcode we should fail the connection
self.handler.context.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)])
self.handler.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)])
throw WebSocketHandler.InternalError.close(.protocolError)
}
} catch {
self.handler.context.logger.trace("Error: \(error)")
self.handler.logger.trace("Error: \(error)")
// catch errors while processing websocket frames so responding close message
// can be dealt with
let errorCode = WebSocketErrorCode(error)
try await self.handler.close(code: errorCode)
do {
try await self.handler.close(code: errorCode)
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/HummingbirdWSCore/WebSocketOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public struct WebSocketOutboundWriter: Sendable {
try await self.handler.write(frame: .init(fin: true, opcode: .binary, data: buffer))
case .text(let string):
// send text based data
let buffer = self.handler.context.allocator.buffer(string: string)
let buffer = self.handler.allocator.buffer(string: string)
try await self.handler.write(frame: .init(fin: true, opcode: .text, data: buffer))
case .pong:
// send unexplained pong as a heartbeat
Expand Down Expand Up @@ -73,7 +73,7 @@ public struct WebSocketOutboundWriter: Sendable {

/// Write string to WebSocket frame
public mutating func callAsFunction(_ text: String) async throws {
let buffer = self.handler.context.allocator.buffer(string: text)
let buffer = self.handler.allocator.buffer(string: text)
try await self.write(buffer, opcode: self.opcode)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extension HTTPServerBuilder {
public static func http1WebSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<BasicWebSocketContext>>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<HTTP1WebSocketUpgradeChannel.Context>>
) -> HTTPServerBuilder {
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
Expand All @@ -41,7 +41,7 @@ extension HTTPServerBuilder {
public static func http1WebSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<BasicWebSocketContext>>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<HTTP1WebSocketUpgradeChannel.Context>>
) -> HTTPServerBuilder {
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
Expand Down
Loading

0 comments on commit e5e169c

Please sign in to comment.