Skip to content

Commit a6c986c

Browse files
committed
Add closing client connection and client ping
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
1 parent 66d3385 commit a6c986c

File tree

4 files changed

+94
-24
lines changed

4 files changed

+94
-24
lines changed

Sources/NatsSwift/NatsClient/NatsClient.swift

+9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ public enum NatsState {
1717
case Pending
1818
case Connected
1919
case Disconnected
20+
case Closed
2021
}
2122

2223
public struct Auth {
@@ -67,6 +68,14 @@ extension Client {
6768
}
6869
try await connectionHandler.connect()
6970
}
71+
72+
public func close() async throws {
73+
logger.debug("close")
74+
guard let connectionHandler = self.connectionHandler else {
75+
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
76+
}
77+
try await connectionHandler.close()
78+
}
7079

7180
public func publish(_ payload: Data, subject: String, reply: String? = nil, headers: HeaderMap? = nil) throws {
7281
logger.debug("publish")

Sources/NatsSwift/NatsConnection.swift

+79-8
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ import Foundation
77
import NIO
88
import NIOFoundationCompat
99
import Dispatch
10+
import Atomics
1011

1112
class ConnectionHandler: ChannelInboundHandler {
1213
let lang = "Swift"
1314
let version = "0.0.1"
1415

1516
internal let allocator = ByteBufferAllocator()
1617
internal var inputBuffer: ByteBuffer
18+
19+
// Connection options
1720
internal var urls: [URL]
1821
// nanoseconds representation of TimeInterval
1922
internal let reconnectWait: UInt64
@@ -23,10 +26,12 @@ class ConnectionHandler: ChannelInboundHandler {
2326
typealias InboundIn = ByteBuffer
2427
internal var state: NatsState = .Pending
2528
internal var subscriptions: [ UInt64: Subscription ]
26-
internal var subscriptionCounter = SubscriptionCounter()
29+
internal var subscriptionCounter = ManagedAtomic<UInt64>(0)
2730
internal var serverInfo: ServerInfo?
2831
internal var auth: Auth?
2932
private var parseRemainder: Data?
33+
private var pingTask: RepeatedTask?
34+
private var outstandingPings = ManagedAtomic<UInt8>(0)
3035

3136

3237
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
@@ -50,28 +55,29 @@ class ConnectionHandler: ChannelInboundHandler {
5055
}
5156
for op in parseResult.ops {
5257
if let continuation = self.serverInfoContinuation {
58+
self.serverInfoContinuation = nil
5359
logger.debug("server info")
5460
switch op {
5561
case let .Error(err):
5662
continuation.resume(throwing: err)
5763
case let .Info(info):
5864
continuation.resume(returning: info)
5965
default:
60-
continuation.resume(throwing: NSError(domain: "nats_swift", code: 1, userInfo: ["message": "unexpected operation; expected server info: \(op)"]))
66+
// ignore until we get either error or server info
67+
continue
6168
}
62-
self.serverInfoContinuation = nil
6369
continue
6470
}
6571

6672
if let continuation = self.connectionEstablishedContinuation {
73+
self.connectionEstablishedContinuation = nil
6774
logger.debug("conn established")
6875
switch op {
6976
case let .Error(err):
7077
continuation.resume(throwing: err)
7178
default:
7279
continuation.resume()
7380
}
74-
self.connectionEstablishedContinuation = nil
7581
continue
7682
}
7783

@@ -85,6 +91,9 @@ class ConnectionHandler: ChannelInboundHandler {
8591
logger.error("error sending pong: \(error)")
8692
continue
8793
}
94+
case .Pong:
95+
logger.debug("pong")
96+
self.outstandingPings.store(0, ordering: AtomicStoreOrdering.relaxed)
8897
case let .Error(err):
8998
logger.debug("error \(err)")
9099
case let .Message(msg):
@@ -165,9 +174,44 @@ class ConnectionHandler: ChannelInboundHandler {
165174
}
166175
}
167176
self.state = .Connected
177+
guard let channel = self.channel else {
178+
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty channel"])
179+
}
180+
// Schedule the task to send a PING periodically
181+
let pingInterval = TimeAmount.nanoseconds(Int64(self.pingInterval*1_000_000_000))
182+
self.pingTask = channel.eventLoop.scheduleRepeatedTask(initialDelay: pingInterval, delay: pingInterval) { [weak self] task in
183+
self?.sendPing()
184+
}
168185
logger.debug("connection established")
169186
}
170187

188+
func close() async throws {
189+
self.state = .Closed
190+
try await disconnect()
191+
try await self.group.shutdownGracefully()
192+
}
193+
194+
func disconnect() async throws {
195+
self.pingTask?.cancel()
196+
try await self.channel?.close().get()
197+
}
198+
199+
private func sendPing() {
200+
let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(ordering: AtomicUpdateOrdering.relaxed)
201+
if pingsOut > 2 {
202+
handleDisconnect()
203+
return
204+
}
205+
let ping = ClientOp.Ping
206+
do {
207+
try self.write(operation: ping)
208+
logger.debug("sent ping: \(pingsOut)")
209+
} catch {
210+
logger.error("Unable to send ping: \(error)")
211+
}
212+
213+
}
214+
171215
func channelActive(context: ChannelHandlerContext) {
172216
logger.debug("TCP channel active")
173217

@@ -177,18 +221,43 @@ class ConnectionHandler: ChannelInboundHandler {
177221
func channelInactive(context: ChannelHandlerContext) {
178222
logger.debug("TCP channel inactive")
179223

180-
handleDisconnect()
224+
if self.state == .Connected {
225+
handleDisconnect()
226+
}
181227
}
182228

183229
func errorCaught(context: ChannelHandlerContext, error: Error) {
184230
// TODO(pp): implement Close() on the connection and call it here
185231
logger.debug("Encountered error on the channel: \(error)")
186-
self.state = .Disconnected
187-
handleReconnect()
232+
context.close(promise: nil)
233+
if self.state == .Connected {
234+
handleDisconnect()
235+
} else if self.state == .Disconnected {
236+
handleReconnect()
237+
}
188238
}
189239

190240
func handleDisconnect() {
191241
self.state = .Disconnected
242+
if let channel = self.channel {
243+
let promise = channel.eventLoop.makePromise(of: Void.self)
244+
Task {
245+
do {
246+
try await self.disconnect()
247+
promise.succeed()
248+
} catch {
249+
promise.fail(error)
250+
}
251+
}
252+
promise.futureResult.whenComplete { result in
253+
do {
254+
try result.get()
255+
} catch {
256+
logger.error("Error closing connection: \(error)")
257+
}
258+
}
259+
}
260+
192261
handleReconnect()
193262
}
194263

@@ -200,10 +269,12 @@ class ConnectionHandler: ChannelInboundHandler {
200269
try await self.connect()
201270
} catch {
202271
// TODO(pp): add option to set this to exponential backoff (with jitter)
272+
logger.debug("could not reconnect: \(error)")
203273
try await Task.sleep(nanoseconds: self.reconnectWait)
204274
attempts += 1
205275
continue
206276
}
277+
logger.debug("reconnected")
207278
break
208279
}
209280
for (sid, sub) in self.subscriptions {
@@ -242,7 +313,7 @@ class ConnectionHandler: ChannelInboundHandler {
242313
}
243314

244315
func subscribe(_ subject: String) async throws -> Subscription {
245-
let sid = self.subscriptionCounter.next()
316+
let sid = self.subscriptionCounter.wrappingIncrementThenLoad(ordering: AtomicUpdateOrdering.relaxed)
246317
try write(operation: ClientOp.Subscribe((sid, subject, nil)))
247318
let sub = Subscription(subject: subject)
248319
self.subscriptions[sid] = sub

Sources/NatsSwift/NatsSubscription.swift

-12
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,3 @@ public class Subscription: AsyncSequence {
8484
}
8585
}
8686
}
87-
88-
internal class SubscriptionCounter {
89-
private var counter: UInt64 = 0
90-
private let queue = DispatchQueue(label: "io.nats.swift.subscriptionCounter")
91-
92-
func next() -> UInt64 {
93-
queue.sync {
94-
counter+=1
95-
return counter
96-
}
97-
}
98-
}

Tests/NatsSwiftTests/Integration/ConnectionTests.swift

+6-4
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class CoreNatsTests: XCTestCase {
121121
}
122122

123123
// make sure sub receives messages
124-
for await msg in sub {
124+
for await _ in sub {
125125
messagesReceived += 1
126126
if messagesReceived == 10 {
127127
break
@@ -140,16 +140,17 @@ class CoreNatsTests: XCTestCase {
140140
try client.publish(payload, subject: "foo")
141141
}
142142
}
143-
144-
for await msg in sub {
143+
144+
for await _ in sub {
145145
messagesReceived += 1
146-
if messagesReceived == 10 {
146+
if messagesReceived == 20 {
147147
break
148148
}
149149
}
150150

151151
// Check if the total number of messages received matches the number sent
152152
XCTAssertEqual(20, messagesReceived, "Mismatch in the number of messages sent and received")
153+
try await client.close()
153154
}
154155

155156
func testUsernameAndPassword() async throws {
@@ -226,3 +227,4 @@ class CoreNatsTests: XCTestCase {
226227

227228
}
228229
}
230+

0 commit comments

Comments
 (0)