Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add specific contexts for websocket types, delete BasicWebSocketContext #67

Merged
merged 6 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Sources/HummingbirdWSClient/Client/ClientConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,15 @@ public struct ClientConnection<ClientChannel: ClientConnectionChannel>: Sendable
/// create a BSD sockets based bootstrap
private func createSocketsBootstrap() -> ClientBootstrap {
return ClientBootstrap(group: self.eventLoopGroup)
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
}

#if canImport(Network)
/// create a NIOTransportServices bootstrap using Network.framework
private func createTSBootstrap() -> NIOTSConnectionBootstrap? {
guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup) else {
guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup)
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
else {
return nil
}
if let tlsOptions {
Expand Down
22 changes: 17 additions & 5 deletions Sources/HummingbirdWSClient/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ import NIOWebSocket
/// }
/// ```
public struct WebSocketClient {
/// Basic context implementation of ``WebSocketContext``.
/// Used by non-router web socket handle function
public struct Context: WebSocketContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

package init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}

enum MultiPlatformTLSConfiguration: Sendable {
case niossl(TLSConfiguration)
#if canImport(Network)
Expand All @@ -50,7 +62,7 @@ public struct WebSocketClient {
/// WebSocket URL
let url: URI
/// WebSocket data handler
let handler: WebSocketDataHandler<BasicWebSocketContext>
let handler: WebSocketDataHandler<Context>
/// configuration
let configuration: WebSocketClientConfiguration
/// EventLoopGroup to use
Expand All @@ -75,7 +87,7 @@ public struct WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) {
self.url = .init(url)
self.handler = handler
Expand All @@ -101,7 +113,7 @@ public struct WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) {
self.url = .init(url)
self.handler = handler
Expand Down Expand Up @@ -195,7 +207,7 @@ extension WebSocketClient {
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) async throws -> WebSocketCloseFrame? {
let ws = self.init(
url: url,
Expand Down Expand Up @@ -225,7 +237,7 @@ extension WebSocketClient {
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
handler: @escaping WebSocketDataHandler<Context>
) async throws -> WebSocketCloseFrame? {
let ws = self.init(
url: url,
Expand Down
6 changes: 3 additions & 3 deletions Sources/HummingbirdWSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ struct WebSocketClientChannel: ClientConnectionChannel {

let urlPath: String
let hostHeader: String
let handler: WebSocketDataHandler<BasicWebSocketContext>
let handler: WebSocketDataHandler<WebSocketClient.Context>
let configuration: WebSocketClientConfiguration

init(handler: @escaping WebSocketDataHandler<BasicWebSocketContext>, url: URI, configuration: WebSocketClientConfiguration) throws {
init(handler: @escaping WebSocketDataHandler<WebSocketClient.Context>, url: URI, configuration: WebSocketClientConfiguration) throws {
guard let hostHeader = Self.urlHostHeader(for: url) else { throw WebSocketClientError.invalidURL }
self.hostHeader = hostHeader
self.urlPath = Self.urlPath(for: url)
Expand Down Expand Up @@ -104,7 +104,7 @@ struct WebSocketClientChannel: ClientConnectionChannel {
autoPing: self.configuration.autoPing
),
asyncChannel: webSocketChannel,
context: BasicWebSocketContext(allocator: webSocketChannel.channel.allocator, logger: logger),
context: WebSocketClient.Context(allocator: webSocketChannel.channel.allocator, logger: logger),
handler: self.handler
)
case .notUpgraded:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try self.decompressor.startStream()
}

func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContext) throws -> WebSocketFrame {
func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame {
if self.state == .idle {
if frame.rsv1 {
self.state = .decompressingMessage
Expand Down Expand Up @@ -231,7 +231,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try self.compressor.startStream()
}

func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContext) throws -> WebSocketFrame {
func compress(_ frame: WebSocketFrame, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame {
// if the frame is larger than `minFrameSizeToCompress` bytes, we haven't received a final frame
// or we are in the process of sending a message compress the data
let shouldWeCompress = frame.data.readableBytes >= self.minFrameSizeToCompress || !frame.fin || self.sendState != .idle
Expand Down Expand Up @@ -292,7 +292,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try? await self.compressor.shutdown()
}

func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame {
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame {
return try await self.decompressor.decompress(
frame,
maxSize: self.configuration.maxDecompressedFrameSize,
Expand All @@ -301,7 +301,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
)
}

func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame {
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame {
let isCorrectType = frame.opcode == .text || frame.opcode == .binary || frame.opcode == .continuation
if isCorrectType {
return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context)
Expand Down
15 changes: 1 addition & 14 deletions Sources/HummingbirdWSCore/WebSocketContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,8 @@
import Logging
import NIOCore

/// Context for WebSocket
/// Protocol for WebSocket Data handling functions context parameter
public protocol WebSocketContext: Sendable {
var allocator: ByteBufferAllocator { get }
var logger: Logger { get }
}

/// Basic context implementation of ``WebSocketContext``.
///
/// Used by non-router and client WebSocket connections
public struct BasicWebSocketContext: WebSocketContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

package init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}
17 changes: 15 additions & 2 deletions Sources/HummingbirdWSCore/WebSocketExtension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@

import Foundation
import HTTPTypes
import Logging
import NIOCore
import NIOWebSocket

/// Basic context implementation of ``WebSocketContext``.
public struct WebSocketExtensionContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}

/// Protocol for WebSocket extension
public protocol WebSocketExtension: Sendable {
/// Extension name
var name: String { get }
/// Process frame received from websocket
func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// Process frame about to be sent to websocket
func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// shutdown extension
func shutdown() async
}
Expand Down
38 changes: 23 additions & 15 deletions Sources/HummingbirdWSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ package actor WebSocketHandler {
var outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>
let type: WebSocketType
let configuration: Configuration
let context: BasicWebSocketContext
let logger: Logger
let allocator: ByteBufferAllocator
var pingData: ByteBuffer
var pingTime: ContinuousClock.Instant = .now
var closeState: CloseState
Expand All @@ -92,7 +93,8 @@ package actor WebSocketHandler {
self.outbound = outbound
self.type = type
self.configuration = configuration
self.context = .init(allocator: context.allocator, logger: context.logger)
self.logger = context.logger
self.allocator = context.allocator
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize)
self.closeState = .open
}
Expand Down Expand Up @@ -175,16 +177,19 @@ package actor WebSocketHandler {
} catch {
closeCode = .unexpectedServerError
}
try await self.close(code: closeCode)
if case .closing = self.closeState {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
break
do {
try await self.close(code: closeCode)
if case .closing = self.closeState {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
try await self.receivedClose(frame)
break
}
}
}
}
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
} onGracefulShutdown: {
Task {
try? await self.close(code: .normalClosure)
Expand All @@ -201,10 +206,13 @@ package actor WebSocketHandler {
var frame = frame
do {
for ext in self.configuration.extensions {
frame = try await ext.processFrameToSend(frame, context: self.context)
frame = try await ext.processFrameToSend(
frame,
context: WebSocketExtensionContext(allocator: self.allocator, logger: self.logger)
)
}
} catch {
self.context.logger.debug("Closing as we failed to generate valid frame data")
self.logger.debug("Closing as we failed to generate valid frame data")
throw WebSocketHandler.InternalError.close(.unexpectedServerError)
}
// Set mask key if client
Expand All @@ -213,7 +221,7 @@ package actor WebSocketHandler {
}
try await self.outbound.write(frame)

self.context.logger.trace("Sent \(frame.traceDescription)")
self.logger.trace("Sent \(frame.traceDescription)")
}

func finish() {
Expand Down Expand Up @@ -273,7 +281,7 @@ package actor WebSocketHandler {
) async throws {
switch self.closeState {
case .open:
var buffer = self.context.allocator.buffer(capacity: 2 + (reason?.utf8.count ?? 0))
var buffer = self.allocator.buffer(capacity: 2 + (reason?.utf8.count ?? 0))
buffer.write(webSocketErrorCode: code)
if let reason {
buffer.writeString(reason)
Expand Down Expand Up @@ -318,7 +326,7 @@ package actor WebSocketHandler {
.protocolError
}

var buffer = self.context.allocator.buffer(capacity: 2)
var buffer = self.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)

try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
Expand Down
16 changes: 11 additions & 5 deletions Sources/HummingbirdWSCore/WebSocketInboundStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
// parse messages coming from inbound
while let frame = try await self.iterator.next() {
do {
self.handler.context.logger.trace("Received \(frame.traceDescription)")
self.handler.logger.trace("Received \(frame.traceDescription)")
switch frame.opcode {
case .connectionClose:
try await self.handler.receivedClose(frame)
Expand All @@ -70,20 +70,26 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
// apply extensions
var frame = frame
for ext in self.handler.configuration.extensions.reversed() {
frame = try await ext.processReceivedFrame(frame, context: self.handler.context)
frame = try await ext.processReceivedFrame(
frame,
context: WebSocketExtensionContext(allocator: self.handler.allocator, logger: self.handler.logger)
)
}
return .init(from: frame)
default:
// if we receive a reserved opcode we should fail the connection
self.handler.context.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)])
self.handler.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)])
throw WebSocketHandler.InternalError.close(.protocolError)
}
} catch {
self.handler.context.logger.trace("Error: \(error)")
self.handler.logger.trace("Error: \(error)")
// catch errors while processing websocket frames so responding close message
// can be dealt with
let errorCode = WebSocketErrorCode(error)
try await self.handler.close(code: errorCode)
do {
try await self.handler.close(code: errorCode)
// don't propagate error if channel is already closed
} catch ChannelError.ioOnClosedChannel {}
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/HummingbirdWSCore/WebSocketOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public struct WebSocketOutboundWriter: Sendable {
try await self.handler.write(frame: .init(fin: true, opcode: .binary, data: buffer))
case .text(let string):
// send text based data
let buffer = self.handler.context.allocator.buffer(string: string)
let buffer = self.handler.allocator.buffer(string: string)
try await self.handler.write(frame: .init(fin: true, opcode: .text, data: buffer))
case .pong:
// send unexplained pong as a heartbeat
Expand Down Expand Up @@ -73,7 +73,7 @@ public struct WebSocketOutboundWriter: Sendable {

/// Write string to WebSocket frame
public mutating func callAsFunction(_ text: String) async throws {
let buffer = self.handler.context.allocator.buffer(string: text)
let buffer = self.handler.allocator.buffer(string: text)
try await self.write(buffer, opcode: self.opcode)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extension HTTPServerBuilder {
public static func http1WebSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<BasicWebSocketContext>>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<HTTP1WebSocketUpgradeChannel.Context>>
) -> HTTPServerBuilder {
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
Expand All @@ -41,7 +41,7 @@ extension HTTPServerBuilder {
public static func http1WebSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<BasicWebSocketContext>>
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<HTTP1WebSocketUpgradeChannel.Context>>
) -> HTTPServerBuilder {
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
Expand Down
Loading
Loading