@@ -7,13 +7,16 @@ import Foundation
7
7
import NIO
8
8
import NIOFoundationCompat
9
9
import Dispatch
10
+ import Atomics
10
11
11
12
class ConnectionHandler: ChannelInboundHandler {
12
13
let lang = " Swift "
13
14
let version = " 0.0.1 "
14
15
15
16
internal let allocator = ByteBufferAllocator ( )
16
17
internal var inputBuffer : ByteBuffer
18
+
19
+ // Connection options
17
20
internal var urls : [ URL ]
18
21
// nanoseconds representation of TimeInterval
19
22
internal let reconnectWait : UInt64
@@ -23,10 +26,12 @@ class ConnectionHandler: ChannelInboundHandler {
23
26
typealias InboundIn = ByteBuffer
24
27
internal var state : NatsState = . Pending
25
28
internal var subscriptions : [ UInt64 : Subscription ]
26
- internal var subscriptionCounter = SubscriptionCounter ( )
29
+ internal var subscriptionCounter = ManagedAtomic < UInt64 > ( 0 )
27
30
internal var serverInfo : ServerInfo ?
28
31
internal var auth : Auth ?
29
32
private var parseRemainder : Data ?
33
+ private var pingTask : RepeatedTask ?
34
+ private var outstandingPings = ManagedAtomic < UInt8 > ( 0 )
30
35
31
36
32
37
func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
@@ -50,28 +55,29 @@ class ConnectionHandler: ChannelInboundHandler {
50
55
}
51
56
for op in parseResult. ops {
52
57
if let continuation = self . serverInfoContinuation {
58
+ self . serverInfoContinuation = nil
53
59
logger. debug ( " server info " )
54
60
switch op {
55
61
case let . Error( err) :
56
62
continuation. resume ( throwing: err)
57
63
case let . Info( info) :
58
64
continuation. resume ( returning: info)
59
65
default :
60
- continuation. resume ( throwing: NSError ( domain: " nats_swift " , code: 1 , userInfo: [ " message " : " unexpected operation; expected server info: \( op) " ] ) )
66
+ // ignore until we get either error or server info
67
+ continue
61
68
}
62
- self . serverInfoContinuation = nil
63
69
continue
64
70
}
65
71
66
72
if let continuation = self . connectionEstablishedContinuation {
73
+ self . connectionEstablishedContinuation = nil
67
74
logger. debug ( " conn established " )
68
75
switch op {
69
76
case let . Error( err) :
70
77
continuation. resume ( throwing: err)
71
78
default :
72
79
continuation. resume ( )
73
80
}
74
- self . connectionEstablishedContinuation = nil
75
81
continue
76
82
}
77
83
@@ -85,6 +91,9 @@ class ConnectionHandler: ChannelInboundHandler {
85
91
logger. error ( " error sending pong: \( error) " )
86
92
continue
87
93
}
94
+ case . Pong:
95
+ logger. debug ( " pong " )
96
+ self . outstandingPings. store ( 0 , ordering: AtomicStoreOrdering . relaxed)
88
97
case let . Error( err) :
89
98
logger. debug ( " error \( err) " )
90
99
case let . Message( msg) :
@@ -165,9 +174,44 @@ class ConnectionHandler: ChannelInboundHandler {
165
174
}
166
175
}
167
176
self. state = . Connected
177
+ guard let channel = self . channel else {
178
+ throw NSError ( domain: " nats_swift " , code: 1 , userInfo: [ " message " : " empty channel " ] )
179
+ }
180
+ // Schedule the task to send a PING periodically
181
+ let pingInterval = TimeAmount . nanoseconds ( Int64 ( self . pingInterval*1_000_000_000) )
182
+ self . pingTask = channel. eventLoop. scheduleRepeatedTask ( initialDelay: pingInterval, delay: pingInterval) { [ weak self] task in
183
+ self ? . sendPing ( )
184
+ }
168
185
logger. debug ( " connection established " )
169
186
}
170
187
188
+ func close( ) async throws {
189
+ self . state = . Closed
190
+ try await disconnect ( )
191
+ try await self . group. shutdownGracefully ( )
192
+ }
193
+
194
+ func disconnect( ) async throws {
195
+ self . pingTask? . cancel ( )
196
+ try await self . channel? . close ( ) . get ( )
197
+ }
198
+
199
+ private func sendPing( ) {
200
+ let pingsOut = self . outstandingPings. wrappingIncrementThenLoad ( ordering: AtomicUpdateOrdering . relaxed)
201
+ if pingsOut > 2 {
202
+ handleDisconnect ( )
203
+ return
204
+ }
205
+ let ping = ClientOp . Ping
206
+ do {
207
+ try self . write ( operation: ping)
208
+ logger. debug ( " sent ping: \( pingsOut) " )
209
+ } catch {
210
+ logger. error ( " Unable to send ping: \( error) " )
211
+ }
212
+
213
+ }
214
+
171
215
func channelActive( context: ChannelHandlerContext) {
172
216
logger. debug ( " TCP channel active " )
173
217
@@ -177,18 +221,43 @@ class ConnectionHandler: ChannelInboundHandler {
177
221
func channelInactive( context: ChannelHandlerContext) {
178
222
logger. debug ( " TCP channel inactive " )
179
223
180
- handleDisconnect ( )
224
+ if self . state == . Connected {
225
+ handleDisconnect ( )
226
+ }
181
227
}
182
228
183
229
func errorCaught( context: ChannelHandlerContext, error: Error) {
184
230
// TODO(pp): implement Close() on the connection and call it here
185
231
logger. debug ( " Encountered error on the channel: \( error) " )
186
- self . state = . Disconnected
187
- handleReconnect ( )
232
+ context. close ( promise: nil )
233
+ if self . state == . Connected {
234
+ handleDisconnect ( )
235
+ } else if self . state == . Disconnected {
236
+ handleReconnect ( )
237
+ }
188
238
}
189
239
190
240
func handleDisconnect( ) {
191
241
self . state = . Disconnected
242
+ if let channel = self . channel {
243
+ let promise = channel. eventLoop. makePromise ( of: Void . self)
244
+ Task {
245
+ do {
246
+ try await self . disconnect ( )
247
+ promise. succeed ( )
248
+ } catch {
249
+ promise. fail ( error)
250
+ }
251
+ }
252
+ promise. futureResult. whenComplete { result in
253
+ do {
254
+ try result. get ( )
255
+ } catch {
256
+ logger. error ( " Error closing connection: \( error) " )
257
+ }
258
+ }
259
+ }
260
+
192
261
handleReconnect ( )
193
262
}
194
263
@@ -200,10 +269,12 @@ class ConnectionHandler: ChannelInboundHandler {
200
269
try await self . connect ( )
201
270
} catch {
202
271
// TODO(pp): add option to set this to exponential backoff (with jitter)
272
+ logger. debug ( " could not reconnect: \( error) " )
203
273
try await Task . sleep ( nanoseconds: self . reconnectWait)
204
274
attempts += 1
205
275
continue
206
276
}
277
+ logger. debug ( " reconnected " )
207
278
break
208
279
}
209
280
for (sid, sub) in self . subscriptions {
@@ -242,7 +313,7 @@ class ConnectionHandler: ChannelInboundHandler {
242
313
}
243
314
244
315
func subscribe( _ subject: String) async throws -> Subscription {
245
- let sid = self . subscriptionCounter. next ( )
316
+ let sid = self . subscriptionCounter. wrappingIncrementThenLoad ( ordering : AtomicUpdateOrdering . relaxed )
246
317
try write ( operation: ClientOp . Subscribe ( ( sid, subject, nil ) ) )
247
318
let sub = Subscription ( subject: subject)
248
319
self . subscriptions [ sid] = sub
0 commit comments