diff --git a/Sources/Valkey/Node/ValkeyNodeClient.swift b/Sources/Valkey/Node/ValkeyNodeClient.swift index 83d9edfc..e5a08d3e 100644 --- a/Sources/Valkey/Node/ValkeyNodeClient.swift +++ b/Sources/Valkey/Node/ValkeyNodeClient.swift @@ -35,6 +35,13 @@ package final class ValkeyNodeClient: Sendable { ValkeyClientMetrics, ContinuousClock > + @usableFromInline + typealias ConnectionStateMachine = + SubscriptionConnectionStateMachine< + ValkeyConnection, + CheckedContinuation, + CheckedContinuation + > /// Server address public let serverAddress: ValkeyServerAddress /// Connection pool @@ -47,6 +54,17 @@ package final class ValkeyNodeClient: Sendable { public let eventLoopGroup: any EventLoopGroup /// Logger public let logger: Logger + /// subscription connection state + @usableFromInline + let subscriptionConnectionStateMachine: Mutex + @usableFromInline + let subscriptionConnectionIDGenerator: ConnectionIDGenerator + /// Actions that can be run on a node + enum RunAction: Sendable { + case leaseSubscriptionConnection(leaseID: Int) + } + let actionStream: AsyncStream + let actionStreamContinuation: AsyncStream.Continuation package init( _ address: ValkeyServerAddress, @@ -85,6 +103,9 @@ package final class ValkeyNodeClient: Sendable { self.connectionFactory = connectionFactory self.eventLoopGroup = eventLoopGroup self.logger = logger + self.subscriptionConnectionStateMachine = .init(.init()) + self.subscriptionConnectionIDGenerator = .init() + (self.actionStream, self.actionStreamContinuation) = AsyncStream.makeStream(of: RunAction.self) } } @@ -93,7 +114,18 @@ extension ValkeyNodeClient { /// Run ValkeyNode connection pool @usableFromInline package func run() async { - await self.connectionPool.run() + /// Run discarding task group running actions + await withDiscardingTaskGroup { group in + group.addTask { + await self.connectionPool.run() + self.shutdownSubscriptionConnection() + } + for await action in self.actionStream { + group.addTask { + await self.runAction(action) + } + } + } } func triggerForceShutdown() { @@ -213,3 +245,25 @@ extension ValkeyNodeClient: ValkeyNodeConnectionPool { self.triggerForceShutdown() } } + +@available(valkeySwift 1.0, *) +extension ValkeyNodeClient { + func queueAction(_ action: RunAction) { + self.actionStreamContinuation.yield(action) + } + + private func runAction(_ action: RunAction) async { + switch action { + case .leaseSubscriptionConnection(let leaseID): + do { + try await self.withConnection { connection in + await withCheckedContinuation { (cont: CheckedContinuation) in + self.acquiredSubscriptionConnection(leaseID: leaseID, connection: connection, releaseContinuation: cont) + } + } + } catch { + self.errorAcquiringSubscriptionConnection(leaseID: leaseID, error: error) + } + } + } +} diff --git a/Sources/Valkey/Subscriptions/SubscriptionConnectionStateMachine.swift b/Sources/Valkey/Subscriptions/SubscriptionConnectionStateMachine.swift index c32ab97b..156002e0 100644 --- a/Sources/Valkey/Subscriptions/SubscriptionConnectionStateMachine.swift +++ b/Sources/Valkey/Subscriptions/SubscriptionConnectionStateMachine.swift @@ -11,7 +11,7 @@ import Synchronization import _ValkeyConnectionPool @available(valkeySwift 1.0, *) -extension ValkeyClient { +extension ValkeyNodeClient { @usableFromInline func leaseSubscriptionConnection(id: Int, request: CheckedContinuation) { self.logger.trace("Get subscription connection", metadata: ["valkey_subscription_connection_id": .stringConvertible(id)]) @@ -102,6 +102,25 @@ extension ValkeyClient { break } } + + @usableFromInline + func shutdownSubscriptionConnection() { + self.logger.trace("Shutdown subscription connection") + let action = self.subscriptionConnectionStateMachine.withLock { stateMachine in + stateMachine.shutdown() + } + switch action { + case .yield(let continuations): + for cont in continuations { + cont.resume(throwing: ValkeyClientError(.connectionClosing)) + } + case .release(let continuation): + continuation.resume() + self.logger.trace("Released connection for subscriptions") + case .doNothing: + break + } + } } /// StateMachine for acquiring Subscription Connection. @@ -114,6 +133,8 @@ struct SubscriptionConnectionStateMachine: ~Copy case acquiring(leaseID: Int, waiters: [Int: Request]) /// We have a connection case acquired(AcquiredState) + /// Connection is shutdown + case shutdown struct AcquiredState { var leaseID: Int @@ -151,6 +172,8 @@ struct SubscriptionConnectionStateMachine: ~Copy state.requestIDs.insert(id) self = .acquired(state) return .completeRequest(state.value) + case .shutdown: + preconditionFailure("Cannot get subscription connection when shutdown") } } @@ -185,6 +208,9 @@ struct SubscriptionConnectionStateMachine: ~Copy self = .acquired(state) return .doNothing } + case .shutdown: + self = .shutdown + return .doNothing } } @@ -213,6 +239,9 @@ struct SubscriptionConnectionStateMachine: ~Copy } else { preconditionFailure("Acquired connection twice") } + case .shutdown: + self = .shutdown + return .release } } @@ -241,6 +270,9 @@ struct SubscriptionConnectionStateMachine: ~Copy } else { preconditionFailure("Error acquiring connection we already have") } + case .shutdown: + self = .shutdown + return .doNothing } } @@ -264,6 +296,32 @@ struct SubscriptionConnectionStateMachine: ~Copy self = .acquired(state) return .doNothing } + case .shutdown: + self = .shutdown + return .doNothing + } + } + + enum ShutdownAction { + case yield([Request]) + case release(ReleaseRequest) + case doNothing + } + + mutating func shutdown() -> ShutdownAction { + switch consume self.state { + case .uninitialized(let leaseID): + self = .uninitialized(nextLeaseID: leaseID) + return .doNothing + case .acquiring(let storedLeaseID, let waiters): + self = .uninitialized(nextLeaseID: storedLeaseID + 1) + return .yield(.init(waiters.values)) + case .acquired(let state): + self = .uninitialized(nextLeaseID: state.leaseID + 1) + return .release(state.releaseRequest) + case .shutdown: + self = .shutdown + return .doNothing } } @@ -272,10 +330,12 @@ struct SubscriptionConnectionStateMachine: ~Copy case .uninitialized: true case .acquiring: false case .acquired: false + case .shutdown: true } } static private func uninitialized(nextLeaseID: Int) -> Self { .init(state: .uninitialized(nextLeaseID: nextLeaseID)) } static private func acquiring(leaseID: Int, waiters: [Int: Request]) -> Self { .init(state: .acquiring(leaseID: leaseID, waiters: waiters)) } static private func acquired(_ state: State.AcquiredState) -> Self { .init(state: .acquired(state)) } + static private var shutdown: Self { .init(state: .shutdown) } } diff --git a/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift b/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift index 483eca28..3cdd843b 100644 --- a/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift +++ b/Sources/Valkey/Subscriptions/ValkeyClient+subscribe.swift @@ -20,18 +20,19 @@ extension ValkeyClient { isolation: isolated (any Actor)? = #isolation, _ operation: (ValkeyConnection) async throws -> sending Value ) async throws -> sending Value { - let id = self.subscriptionConnectionIDGenerator.next() + let node = self.node + let id = node.subscriptionConnectionIDGenerator.next() let connection = try await withTaskCancellationHandler { try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in - self.leaseSubscriptionConnection(id: id, request: cont) + node.leaseSubscriptionConnection(id: id, request: cont) } } onCancel: { - self.cancelSubscriptionConnection(id: id) + node.cancelSubscriptionConnection(id: id) } defer { - self.releaseSubscriptionConnection(id: id) + node.releaseSubscriptionConnection(id: id) } return try await operation(connection) } diff --git a/Sources/Valkey/Subscriptions/ValkeyClusterClient+subscribe.swift b/Sources/Valkey/Subscriptions/ValkeyClusterClient+subscribe.swift new file mode 100644 index 00000000..2824af55 --- /dev/null +++ b/Sources/Valkey/Subscriptions/ValkeyClusterClient+subscribe.swift @@ -0,0 +1,173 @@ +// +// This source file is part of the valkey-swift project +// Copyright (c) 2025 the valkey-swift project authors +// +// See LICENSE.txt for license information +// SPDX-License-Identifier: Apache-2.0 +// +import NIOCore +import Synchronization + +@available(valkeySwift 1.0, *) +extension ValkeyClusterClient { + /// Run operation with the valkey subscription connection + /// + /// - Parameters: + /// - isolation: Actor isolation + /// - operation: Closure to run with subscription connection + @inlinable + func withSubscriptionConnection( + isolation: isolated (any Actor)? = #isolation, + _ operation: (ValkeyConnection) async throws -> sending Value + ) async throws -> sending Value { + let node = try await self.nodeClient(for: []) + let id = node.subscriptionConnectionIDGenerator.next() + + let connection = try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + node.leaseSubscriptionConnection(id: id, request: cont) + } + } onCancel: { + node.cancelSubscriptionConnection(id: id) + } + + defer { + node.releaseSubscriptionConnection(id: id) + } + return try await operation(connection) + } + + /// Subscribe to list of channels and run closure with subscription + /// + /// When the closure is exited the channels are automatically unsubscribed from. + /// + /// When running subscribe from `ValkeyClient` a single connection is used for + /// all subscriptions. + /// + /// - Parameters: + /// - channels: list of channels to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + @inlinable + public func subscribe( + to channels: String..., + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> sending Value { + try await self.subscribe(to: channels, process: process) + } + + @inlinable + /// Subscribe to list of channels and run closure with subscription + /// + /// When the closure is exited the channels are automatically unsubscribed from. + /// + /// When running subscribe from `ValkeyClient` a single connection is used for + /// all subscriptions. + /// + /// - Parameters: + /// - channels: list of channels to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + public func subscribe( + to channels: [String], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> sending Value { + try await self.subscribe( + command: SUBSCRIBE(channels: channels), + filters: channels.map { .channel($0) }, + process: process + ) + } + + /// Subscribe to list of channel patterns and run closure with subscription + /// + /// When the closure is exited the patterns are automatically unsubscribed from. + /// + /// When running subscribe from `ValkeyClient` a single connection is used for + /// all subscriptions. + /// + /// - Parameters: + /// - patterns: list of channel patterns to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + @inlinable + public func psubscribe( + to patterns: String..., + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> sending Value { + try await self.psubscribe(to: patterns, process: process) + } + + /// Subscribe to list of pattern matching channels and run closure with subscription + /// + /// When the closure is exited the patterns are automatically unsubscribed from. + /// + /// When running subscribe from `ValkeyClient` a single connection is used for + /// all subscriptions. + /// + /// - Parameters: + /// - patterns: list of channel patterns to subscribe to + /// - isolation: Actor isolation + /// - process: Closure that is called with subscription async sequence + /// - Returns: Return value of closure + @inlinable + public func psubscribe( + to patterns: [String], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> sending Value { + try await self.subscribe( + command: PSUBSCRIBE(patterns: patterns), + filters: patterns.map { .pattern($0) }, + process: process + ) + } + + /// Subscribe to key invalidation channel required for client-side caching + /// + /// See https://valkey.io/topics/client-side-caching/ for more details. The `process` + /// closure is provided with a stream of ValkeyKeys that have been invalidated and also + /// the client id of the subscription connection to redirect client tracking messages to. + /// + /// When the closure is exited the channel is automatically unsubscribed from. + /// + /// When running subscribe from `ValkeyClient` a single connection is used for + /// all subscriptions. + /// + /// - Parameters: + /// - isolation: Actor isolation + /// - process: Closure that is called with async sequence of key invalidations and the client id + /// of the connection the subscription is running on. + /// - Returns: Return value of closure + @inlinable + public func subscribeKeyInvalidations( + isolation: isolated (any Actor)? = #isolation, + process: (AsyncMapSequence, Int) async throws -> sending Value + ) async throws -> sending Value { + try await withSubscriptionConnection { connection in + let id = try await connection.clientId() + return try await connection.subscribe(to: [ValkeySubscriptions.invalidateChannel]) { subscription in + let keys = subscription.map { ValkeyKey($0.message) } + return try await process(keys, id) + } + } + } + + @inlinable + func subscribe( + command: some ValkeyCommand, + filters: [ValkeySubscriptionFilter], + isolation: isolated (any Actor)? = #isolation, + process: (ValkeySubscription) async throws -> sending Value + ) async throws -> sending Value { + try await self.withSubscriptionConnection { connection in + try await connection.subscribe(command: command, filters: filters, process: process) + } + } +} diff --git a/Sources/Valkey/ValkeyClient.swift b/Sources/Valkey/ValkeyClient.swift index 28e648f0..63643bed 100644 --- a/Sources/Valkey/ValkeyClient.swift +++ b/Sources/Valkey/ValkeyClient.swift @@ -22,14 +22,6 @@ import ServiceLifecycle /// `ValkeyClient` supports TLS using both NIOSSL and the Network framework. @available(valkeySwift 1.0, *) public final class ValkeyClient: Sendable { - @usableFromInline - typealias ConnectionStateMachine = - SubscriptionConnectionStateMachine< - ValkeyConnection, - CheckedContinuation, - CheckedContinuation - > - let nodeClientFactory: ValkeyNodeClientFactory /// single node @usableFromInline @@ -42,15 +34,9 @@ public final class ValkeyClient: Sendable { let logger: Logger /// running atomic let runningAtomic: Atomic - /// subscription connection state - @usableFromInline - let subscriptionConnectionStateMachine: Mutex - @usableFromInline - let subscriptionConnectionIDGenerator: ConnectionIDGenerator enum RunAction: Sendable { case runNodeClient(ValkeyNodeClient) - case leaseSubscriptionConnection(leaseID: Int) } let actionStream: AsyncStream let actionStreamContinuation: AsyncStream.Continuation @@ -97,8 +83,6 @@ public final class ValkeyClient: Sendable { self.logger = logger self.runningAtomic = .init(false) self.node = self.nodeClientFactory.makeConnectionPool(serverAddress: address) - self.subscriptionConnectionStateMachine = .init(.init()) - self.subscriptionConnectionIDGenerator = .init() (self.actionStream, self.actionStreamContinuation) = AsyncStream.makeStream(of: RunAction.self) self.queueAction(.runNodeClient(self.node)) } @@ -155,16 +139,6 @@ extension ValkeyClient { switch action { case .runNodeClient(let nodeClient): await nodeClient.run() - case .leaseSubscriptionConnection(let leaseID): - do { - try await self.withConnection { connection in - await withCheckedContinuation { (cont: CheckedContinuation) in - self.acquiredSubscriptionConnection(leaseID: leaseID, connection: connection, releaseContinuation: cont) - } - } - } catch { - self.errorAcquiringSubscriptionConnection(leaseID: leaseID, error: error) - } } } } diff --git a/Tests/ClusterIntegrationTests/ClusterIntegrationTests.swift b/Tests/ClusterIntegrationTests/ClusterIntegrationTests.swift index 2fc9b0b1..bad38b3a 100644 --- a/Tests/ClusterIntegrationTests/ClusterIntegrationTests.swift +++ b/Tests/ClusterIntegrationTests/ClusterIntegrationTests.swift @@ -135,6 +135,130 @@ struct ClusterIntegrationTests { } } } + @Test + @available(valkeySwift 1.0, *) + func testClusterClientSubscriptions() async throws { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + var logger = Logger(label: "Subscriptions") + logger.logLevel = .trace + let firstNodeHostname = clusterFirstNodeHostname! + let firstNodePort = clusterFirstNodePort ?? 6379 + try await Self.withValkeyCluster([(host: firstNodeHostname, port: firstNodePort)], logger: logger) { client in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await client.subscribe(to: "testSubscriptions") { subscription in + cont.finish() + var iterator = subscription.makeAsyncIterator() + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "hello" } + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "goodbye" } + } + } + await stream.first { _ in true } + try await Task.sleep(for: .milliseconds(100)) + _ = try await client.publish(channel: "testSubscriptions", message: "hello") + _ = try await client.publish(channel: "testSubscriptions", message: "goodbye") + try await group.waitForAll() + } + } + } + + @Test + @available(valkeySwift 1.0, *) + func testClientSubscriptionsTwice() async throws { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + var logger = Logger(label: "Subscriptions") + logger.logLevel = .trace + let firstNodeHostname = clusterFirstNodeHostname! + let firstNodePort = clusterFirstNodePort ?? 6379 + try await Self.withValkeyCluster([(host: firstNodeHostname, port: firstNodePort)], logger: logger) { client in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await client.subscribe(to: "testSubscriptions") { subscription in + cont.yield() + var iterator = subscription.makeAsyncIterator() + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "hello" } + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "goodbye" } + } + try await client.subscribe(to: "testSubscriptions") { subscription in + cont.finish() + var iterator = subscription.makeAsyncIterator() + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "hello" } + await #expect(throws: Never.self) { try await iterator.next().map { String(buffer: $0.message) } == "goodbye" } + } + } + await stream.first { _ in true } + try await Task.sleep(for: .milliseconds(10)) + _ = try await client.publish(channel: "testSubscriptions", message: "hello") + _ = try await client.publish(channel: "testSubscriptions", message: "goodbye") + await stream.first { _ in true } + try await Task.sleep(for: .milliseconds(10)) + _ = try await client.publish(channel: "testSubscriptions", message: "hello") + _ = try await client.publish(channel: "testSubscriptions", message: "goodbye") + try await group.waitForAll() + } + } + } + + @Test + @available(valkeySwift 1.0, *) + func testClientMultipleSubscriptions() async throws { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + var logger = Logger(label: "Subscriptions") + logger.logLevel = .trace + let firstNodeHostname = clusterFirstNodeHostname! + let firstNodePort = clusterFirstNodePort ?? 6379 + try await Self.withValkeyCluster([(host: firstNodeHostname, port: firstNodePort)], logger: logger) { client in + try await withThrowingTaskGroup(of: Void.self) { group in + let count = 50 + for i in 0..