Skip to content

Commit 86326af

Browse files
authored
Add suspend, resume and reconnect methods (#64)
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
1 parent 7ebee0d commit 86326af

8 files changed

+234
-21
lines changed

README.md

+20-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ Currently, the client supports **Core NATS** with auth, TLS, lame duck mode and
2020

2121
JetStream, KV, Object Store, Service API are on the roadmap.
2222

23-
2423
## Support
2524

2625
Join the [#swift](https://natsio.slack.com/channels/swift) channel on nats.io Slack.
@@ -127,7 +126,6 @@ specific subjects, facilitating asynchronous communication patterns. This exampl
127126
will guide you through creating a subscription to a subject, allowing your application to process
128127
incoming messages as they are received.
129128

130-
131129
```swift
132130
let subscription = try await nats.subscribe(subject: "foo.>")
133131

@@ -175,6 +173,26 @@ nats.on(.connected) { event in
175173
}
176174
```
177175

176+
### AppDelegate or SceneDelegate Integration
177+
178+
In order to make sure the connection is managed properly in your
179+
AppDelegate.swift or SceneDelegate.swift, integrate the NatsClient connection
180+
management as follows:
181+
182+
```swift
183+
func sceneDidBecomeActive(_ scene: UIScene) {
184+
Task {
185+
try await self.natsClient.resume()
186+
}
187+
}
188+
189+
func sceneWillResignActive(_ scene: UIScene) {
190+
Task {
191+
try await self.natsClient.suspend()
192+
}
193+
}
194+
```
195+
178196
## Attribution
179197

180198
This library is based on excellent work in https://github.com/aus-der-Technik/SwiftyNats

Sources/Nats/NatsClient/NatsClient.swift

+27
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public enum NatsState {
2626
case connected
2727
case disconnected
2828
case closed
29+
case suspended
2930
}
3031

3132
public struct Auth {
@@ -85,6 +86,8 @@ extension NatsClient {
8586
}
8687
if !connectionHandler.retryOnFailedConnect {
8788
try await connectionHandler.connect()
89+
connectionHandler.state = .connected
90+
connectionHandler.fire(.connected)
8891
} else {
8992
connectionHandler.handleReconnect()
9093
}
@@ -98,6 +101,30 @@ extension NatsClient {
98101
try await connectionHandler.close()
99102
}
100103

104+
public func suspend() async throws {
105+
logger.debug("suspend")
106+
guard let connectionHandler = self.connectionHandler else {
107+
throw NatsClientError("internal error: empty connection handler")
108+
}
109+
try await connectionHandler.suspend()
110+
}
111+
112+
public func resume() async throws {
113+
logger.debug("resume")
114+
guard let connectionHandler = self.connectionHandler else {
115+
throw NatsClientError("internal error: empty connection handler")
116+
}
117+
try await connectionHandler.resume()
118+
}
119+
120+
public func reconnect() async throws {
121+
logger.debug("resume")
122+
guard let connectionHandler = self.connectionHandler else {
123+
throw NatsClientError("internal error: empty connection handler")
124+
}
125+
try await connectionHandler.reconnect()
126+
}
127+
101128
public func publish(
102129
_ payload: Data, subject: String, reply: String? = nil, headers: NatsHeaderMap? = nil
103130
) async throws {

Sources/Nats/NatsClient/NatsClientOptions.swift

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import Dispatch
1515
import Foundation
16+
import Logging
1617
import NIO
1718
import NIOFoundationCompat
1819

Sources/Nats/NatsConnection.swift

+89-11
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ class ConnectionHandler: ChannelInboundHandler {
4747
private var clientKey: URL?
4848

4949
typealias InboundIn = ByteBuffer
50-
private var state: NatsState = .pending
50+
private let stateLock = NSLock()
51+
internal var state: NatsState = .pending
52+
5153
private var subscriptions: [UInt64: Subscription]
5254
private var subscriptionCounter = ManagedAtomic<UInt64>(0)
5355
private var serverInfo: ServerInfo?
@@ -56,6 +58,7 @@ class ConnectionHandler: ChannelInboundHandler {
5658
private var pingTask: RepeatedTask?
5759
private var outstandingPings = ManagedAtomic<UInt8>(0)
5860
private var reconnectAttempts = 0
61+
private var reconnectTask: Task<(), Never>? = nil
5962

6063
private var group: MultiThreadedEventLoopGroup
6164

@@ -219,7 +222,6 @@ class ConnectionHandler: ChannelInboundHandler {
219222
// if there are more reconnect attempts than the number of servers,
220223
// we are after the initial connect, so sleep between servers
221224
let shouldSleep = self.reconnectAttempts >= self.urls.count
222-
logger.debug("reconnect attempts: \(self.reconnectAttempts)")
223225
for s in servers {
224226
if let maxReconnects {
225227
if reconnectAttempts >= maxReconnects {
@@ -249,8 +251,6 @@ class ConnectionHandler: ChannelInboundHandler {
249251
throw lastErr
250252
}
251253
self.reconnectAttempts = 0
252-
self.state = .connected
253-
self.fire(.connected)
254254
guard let channel = self.channel else {
255255
throw NatsClientError("internal error: empty channel")
256256
}
@@ -530,16 +530,71 @@ class ConnectionHandler: ChannelInboundHandler {
530530
}
531531

532532
func close() async throws {
533-
self.state = .closed
534-
try await disconnect()
533+
self.reconnectTask?.cancel()
534+
await self.reconnectTask?.value
535+
536+
guard let eventLoop = self.channel?.eventLoop else {
537+
throw NatsClientError("internal error: channel should not be nil")
538+
}
539+
let promise = eventLoop.makePromise(of: Void.self)
540+
541+
eventLoop.execute { // This ensures the code block runs on the event loop
542+
self.state = .closed
543+
self.pingTask?.cancel()
544+
self.channel?.close(mode: .all, promise: promise)
545+
}
546+
547+
try await promise.futureResult.get()
535548
self.fire(.closed)
536549
}
537550

538-
func disconnect() async throws {
551+
private func disconnect() async throws {
539552
self.pingTask?.cancel()
540553
try await self.channel?.close().get()
541554
}
542555

556+
func suspend() async throws {
557+
self.reconnectTask?.cancel()
558+
_ = await self.reconnectTask?.value
559+
560+
guard let eventLoop = self.channel?.eventLoop else {
561+
throw NatsClientError("internal error: channel should not be nil")
562+
}
563+
let promise = eventLoop.makePromise(of: Void.self)
564+
565+
eventLoop.execute { // This ensures the code block runs on the event loop
566+
if self.state == .connected {
567+
self.state = .suspended
568+
self.pingTask?.cancel()
569+
self.channel?.close(mode: .all, promise: promise)
570+
} else {
571+
self.state = .suspended
572+
promise.succeed()
573+
}
574+
}
575+
576+
try await promise.futureResult.get()
577+
self.fire(.suspended)
578+
}
579+
580+
func resume() async throws {
581+
guard let eventLoop = self.channel?.eventLoop else {
582+
throw NatsClientError("internal error: channel should not be nil")
583+
}
584+
try await eventLoop.submit {
585+
guard self.state == .suspended else {
586+
throw NatsClientError(
587+
"unable to resume connection - connection is not in suspended state")
588+
}
589+
self.handleReconnect()
590+
}.get()
591+
}
592+
593+
func reconnect() async throws {
594+
try await suspend()
595+
try await resume()
596+
}
597+
543598
internal func sendPing(_ rttCommand: RttCommand? = nil) async {
544599
let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(
545600
ordering: AtomicUpdateOrdering.relaxed)
@@ -627,19 +682,30 @@ class ConnectionHandler: ChannelInboundHandler {
627682
}
628683

629684
func handleReconnect() {
630-
Task {
631-
while maxReconnects == nil || self.reconnectAttempts < maxReconnects! {
685+
reconnectTask = Task {
686+
var reconnected = false
687+
while !Task.isCancelled
688+
&& (maxReconnects == nil || self.reconnectAttempts < maxReconnects!)
689+
{
632690
do {
633691
try await self.connect()
692+
} catch _ as CancellationError {
693+
// task cancelled
694+
return
634695
} catch {
635696
// TODO(pp): add option to set this to exponential backoff (with jitter)
636697
logger.debug("could not reconnect: \(error)")
637698
continue
638699
}
639700
logger.debug("reconnected")
701+
reconnected = true
640702
break
641703
}
642-
if self.state != .connected {
704+
// if task was cancelled when establishing connection, do not attempt to recreate subscriptions
705+
if Task.isCancelled {
706+
return
707+
}
708+
if !reconnected && !Task.isCancelled {
643709
logger.error("could not reconnect; maxReconnects exceeded")
644710
logger.debug("closing connection")
645711
do {
@@ -651,7 +717,15 @@ class ConnectionHandler: ChannelInboundHandler {
651717
return
652718
}
653719
for (sid, sub) in self.subscriptions {
654-
try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
720+
do {
721+
try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
722+
} catch {
723+
logger.error("error recreating subscription \(sid): \(error)")
724+
}
725+
}
726+
self.channel?.eventLoop.execute {
727+
self.state = .connected
728+
self.fire(.connected)
655729
}
656730
}
657731
}
@@ -741,6 +815,7 @@ public enum NatsEventKind: String {
741815
case connected = "connected"
742816
case disconnected = "disconnected"
743817
case closed = "closed"
818+
case suspended = "suspended"
744819
case lameDuckMode = "lameDuckMode"
745820
case error = "error"
746821
static let all = [connected, disconnected, closed, lameDuckMode, error]
@@ -749,6 +824,7 @@ public enum NatsEventKind: String {
749824
public enum NatsEvent {
750825
case connected
751826
case disconnected
827+
case suspended
752828
case closed
753829
case lameDuckMode
754830
case error(NatsError)
@@ -759,6 +835,8 @@ public enum NatsEvent {
759835
return .connected
760836
case .disconnected:
761837
return .disconnected
838+
case .suspended:
839+
return .suspended
762840
case .closed:
763841
return .closed
764842
case .lameDuckMode:

Sources/Nats/NatsSubscription.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public class Subscription: AsyncSequence {
6262
}
6363
}
6464

65-
func complete() {
65+
internal func complete() {
6666
lock.withLock {
6767
closed = true
6868
if let continuation {

0 commit comments

Comments
 (0)