@@ -47,7 +47,9 @@ class ConnectionHandler: ChannelInboundHandler {
47
47
private var clientKey : URL ?
48
48
49
49
typealias InboundIn = ByteBuffer
50
- private var state : NatsState = . pending
50
+ private let stateLock = NSLock ( )
51
+ internal var state : NatsState = . pending
52
+
51
53
private var subscriptions : [ UInt64 : Subscription ]
52
54
private var subscriptionCounter = ManagedAtomic < UInt64 > ( 0 )
53
55
private var serverInfo : ServerInfo ?
@@ -56,6 +58,7 @@ class ConnectionHandler: ChannelInboundHandler {
56
58
private var pingTask : RepeatedTask ?
57
59
private var outstandingPings = ManagedAtomic < UInt8 > ( 0 )
58
60
private var reconnectAttempts = 0
61
+ private var reconnectTask : Task < ( ) , Never > ? = nil
59
62
60
63
private var group : MultiThreadedEventLoopGroup
61
64
@@ -219,7 +222,6 @@ class ConnectionHandler: ChannelInboundHandler {
219
222
// if there are more reconnect attempts than the number of servers,
220
223
// we are after the initial connect, so sleep between servers
221
224
let shouldSleep = self . reconnectAttempts >= self . urls. count
222
- logger. debug ( " reconnect attempts: \( self . reconnectAttempts) " )
223
225
for s in servers {
224
226
if let maxReconnects {
225
227
if reconnectAttempts >= maxReconnects {
@@ -249,8 +251,6 @@ class ConnectionHandler: ChannelInboundHandler {
249
251
throw lastErr
250
252
}
251
253
self . reconnectAttempts = 0
252
- self . state = . connected
253
- self . fire ( . connected)
254
254
guard let channel = self . channel else {
255
255
throw NatsClientError ( " internal error: empty channel " )
256
256
}
@@ -530,16 +530,71 @@ class ConnectionHandler: ChannelInboundHandler {
530
530
}
531
531
532
532
func close( ) async throws {
533
- self . state = . closed
534
- try await disconnect ( )
533
+ self . reconnectTask? . cancel ( )
534
+ await self . reconnectTask? . value
535
+
536
+ guard let eventLoop = self . channel? . eventLoop else {
537
+ throw NatsClientError ( " internal error: channel should not be nil " )
538
+ }
539
+ let promise = eventLoop. makePromise ( of: Void . self)
540
+
541
+ eventLoop. execute { // This ensures the code block runs on the event loop
542
+ self . state = . closed
543
+ self . pingTask? . cancel ( )
544
+ self . channel? . close ( mode: . all, promise: promise)
545
+ }
546
+
547
+ try await promise. futureResult. get ( )
535
548
self . fire ( . closed)
536
549
}
537
550
538
- func disconnect( ) async throws {
551
+ private func disconnect( ) async throws {
539
552
self . pingTask? . cancel ( )
540
553
try await self . channel? . close ( ) . get ( )
541
554
}
542
555
556
+ func suspend( ) async throws {
557
+ self . reconnectTask? . cancel ( )
558
+ _ = await self . reconnectTask? . value
559
+
560
+ guard let eventLoop = self . channel? . eventLoop else {
561
+ throw NatsClientError ( " internal error: channel should not be nil " )
562
+ }
563
+ let promise = eventLoop. makePromise ( of: Void . self)
564
+
565
+ eventLoop. execute { // This ensures the code block runs on the event loop
566
+ if self . state == . connected {
567
+ self . state = . suspended
568
+ self . pingTask? . cancel ( )
569
+ self . channel? . close ( mode: . all, promise: promise)
570
+ } else {
571
+ self . state = . suspended
572
+ promise. succeed ( )
573
+ }
574
+ }
575
+
576
+ try await promise. futureResult. get ( )
577
+ self . fire ( . suspended)
578
+ }
579
+
580
+ func resume( ) async throws {
581
+ guard let eventLoop = self . channel? . eventLoop else {
582
+ throw NatsClientError ( " internal error: channel should not be nil " )
583
+ }
584
+ try await eventLoop. submit {
585
+ guard self . state == . suspended else {
586
+ throw NatsClientError (
587
+ " unable to resume connection - connection is not in suspended state " )
588
+ }
589
+ self . handleReconnect ( )
590
+ } . get ( )
591
+ }
592
+
593
+ func reconnect( ) async throws {
594
+ try await suspend ( )
595
+ try await resume ( )
596
+ }
597
+
543
598
internal func sendPing( _ rttCommand: RttCommand? = nil ) async {
544
599
let pingsOut = self . outstandingPings. wrappingIncrementThenLoad (
545
600
ordering: AtomicUpdateOrdering . relaxed)
@@ -627,19 +682,30 @@ class ConnectionHandler: ChannelInboundHandler {
627
682
}
628
683
629
684
func handleReconnect( ) {
630
- Task {
631
- while maxReconnects == nil || self . reconnectAttempts < maxReconnects! {
685
+ reconnectTask = Task {
686
+ var reconnected = false
687
+ while !Task. isCancelled
688
+ && ( maxReconnects == nil || self . reconnectAttempts < maxReconnects!)
689
+ {
632
690
do {
633
691
try await self . connect ( )
692
+ } catch _ as CancellationError {
693
+ // task cancelled
694
+ return
634
695
} catch {
635
696
// TODO(pp): add option to set this to exponential backoff (with jitter)
636
697
logger. debug ( " could not reconnect: \( error) " )
637
698
continue
638
699
}
639
700
logger. debug ( " reconnected " )
701
+ reconnected = true
640
702
break
641
703
}
642
- if self . state != . connected {
704
+ // if task was cancelled when establishing connection, do not attempt to recreate subscriptions
705
+ if Task . isCancelled {
706
+ return
707
+ }
708
+ if !reconnected && !Task. isCancelled {
643
709
logger. error ( " could not reconnect; maxReconnects exceeded " )
644
710
logger. debug ( " closing connection " )
645
711
do {
@@ -651,7 +717,15 @@ class ConnectionHandler: ChannelInboundHandler {
651
717
return
652
718
}
653
719
for (sid, sub) in self . subscriptions {
654
- try await write ( operation: ClientOp . subscribe ( ( sid, sub. subject, nil ) ) )
720
+ do {
721
+ try await write ( operation: ClientOp . subscribe ( ( sid, sub. subject, nil ) ) )
722
+ } catch {
723
+ logger. error ( " error recreating subscription \( sid) : \( error) " )
724
+ }
725
+ }
726
+ self . channel? . eventLoop. execute {
727
+ self . state = . connected
728
+ self . fire ( . connected)
655
729
}
656
730
}
657
731
}
@@ -741,6 +815,7 @@ public enum NatsEventKind: String {
741
815
case connected = " connected "
742
816
case disconnected = " disconnected "
743
817
case closed = " closed "
818
+ case suspended = " suspended "
744
819
case lameDuckMode = " lameDuckMode "
745
820
case error = " error "
746
821
static let all = [ connected, disconnected, closed, lameDuckMode, error]
@@ -749,6 +824,7 @@ public enum NatsEventKind: String {
749
824
public enum NatsEvent {
750
825
case connected
751
826
case disconnected
827
+ case suspended
752
828
case closed
753
829
case lameDuckMode
754
830
case error ( NatsError)
@@ -759,6 +835,8 @@ public enum NatsEvent {
759
835
return . connected
760
836
case . disconnected:
761
837
return . disconnected
838
+ case . suspended:
839
+ return . suspended
762
840
case . closed:
763
841
return . closed
764
842
case . lameDuckMode:
0 commit comments