Skip to content

Commit ff371f3

Browse files
committed
Add queue subscribe
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
1 parent c841763 commit ff371f3

File tree

5 files changed

+99
-14
lines changed

5 files changed

+99
-14
lines changed

Sources/Nats/NatsClient/NatsClient.swift

+9-6
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ extension NatsClient {
225225
operation: ClientOp.publish((subject, inbox, payload, headers)))
226226

227227
return try await withThrowingTaskGroup(
228-
of: NatsMessage?.self,
229-
body: { group in
228+
of: NatsMessage?.self) { group in
230229
group.addTask {
231230
do {
232231
return try await sub.makeAsyncIterator().next()
@@ -258,7 +257,7 @@ extension NatsClient {
258257

259258
// this should not be reachable
260259
throw NatsError.ClientError.internalError("error waiting for response")
261-
})
260+
}
262261
}
263262

264263
/// Flushes the internal buffer ensuring that all messages are sent.
@@ -277,22 +276,26 @@ extension NatsClient {
277276

278277
/// Subscribes to a subject to receive messages.
279278
///
280-
/// - Parameter subject is a subject the client want's to subscribe to.
279+
/// - Parameters:
280+
/// - subject:a subject the client want's to subscribe to.
281+
/// - queue: optional queue group name.
281282
///
282283
/// - Returns a ``NatsSubscription`` allowing iteration over incoming messages.
283284
///
284285
/// > **Throws:**
285286
/// > - ``NatsError/ClientError/connectionClosed`` if the conneciton is closed.
286287
/// > - ``NatsError/ClientError/io(_:)`` if there is an error sending the SUB request to the server.
287-
public func subscribe(subject: String) async throws -> NatsSubscription {
288+
/// > - ``NatsError/SubscriptionError/invalidSubject`` if the provided subject is invalid.
289+
/// > - ``NatsError/SubscriptionError/invalidQueue`` if the provided queue group is invalid.
290+
public func subscribe(subject: String, queue: String? = nil) async throws -> NatsSubscription {
288291
logger.info("subscribe to subject \(subject)")
289292
guard let connectionHandler = self.connectionHandler else {
290293
throw NatsError.ClientError.internalError("empty connection handler")
291294
}
292295
if case .closed = connectionHandler.state {
293296
throw NatsError.ClientError.connectionClosed
294297
}
295-
return try await connectionHandler.subscribe(subject)
298+
return try await connectionHandler.subscribe(subject, queue: queue)
296299
}
297300

298301
/// Sends a PING to the server, returning the time it took for the server to respond.

Sources/Nats/NatsConnection.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -819,11 +819,11 @@ class ConnectionHandler: ChannelInboundHandler {
819819
}
820820
}
821821

822-
internal func subscribe(_ subject: String) async throws -> NatsSubscription {
822+
internal func subscribe(_ subject: String, queue: String? = nil) async throws -> NatsSubscription {
823823
let sid = self.subscriptionCounter.wrappingIncrementThenLoad(
824824
ordering: AtomicUpdateOrdering.relaxed)
825-
try await write(operation: ClientOp.subscribe((sid, subject, nil)))
826-
let sub = NatsSubscription(sid: sid, subject: subject, conn: self)
825+
let sub = try NatsSubscription(sid: sid, subject: subject, queue: queue, conn: self)
826+
try await write(operation: ClientOp.subscribe((sid, subject, queue)))
827827
self.subscriptions[sid] = sub
828828
return sub
829829
}

Sources/Nats/NatsError.swift

+6
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,17 @@ public enum NatsError {
205205
}
206206

207207
public enum SubscriptionError: NatsErrorProtocol, Equatable {
208+
case invalidSubject
209+
case invalidQueue
208210
case permissionDenied
209211
case subscriptionClosed
210212

211213
public var description: String {
212214
switch self {
215+
case .invalidSubject:
216+
return "nats: invalid subject name"
217+
case .invalidQueue:
218+
return "nats: invalid queue group name"
213219
case .permissionDenied:
214220
return "nats: permission denied"
215221
case .subscriptionClosed:

Sources/Nats/NatsSubscription.swift

+34-5
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import Foundation
1515

1616
// TODO(pp): Implement slow consumer
17-
// TODO(pp): Add queue subscribe
1817
public class NatsSubscription: AsyncSequence {
1918
public typealias Element = NatsMessage
2019
public typealias AsyncIterator = SubscriptionIterator
2120

2221
public let subject: String
22+
public let queue: String?
2323
internal var max: UInt64?
2424
internal var delivered: UInt64 = 0
2525
internal let sid: UInt64
@@ -34,14 +34,21 @@ public class NatsSubscription: AsyncSequence {
3434

3535
private static let defaultSubCapacity: UInt64 = 512 * 1024
3636

37-
convenience init(sid: UInt64, subject: String, conn: ConnectionHandler) {
38-
self.init(
39-
sid: sid, subject: subject, capacity: NatsSubscription.defaultSubCapacity, conn: conn)
37+
convenience init(sid: UInt64, subject: String, queue: String?, conn: ConnectionHandler) throws {
38+
try self.init(
39+
sid: sid, subject: subject, queue: queue, capacity: NatsSubscription.defaultSubCapacity, conn: conn)
4040
}
4141

42-
init(sid: UInt64, subject: String, capacity: UInt64, conn: ConnectionHandler) {
42+
init(sid: UInt64, subject: String, queue: String?, capacity: UInt64, conn: ConnectionHandler) throws {
43+
if !NatsSubscription.validSubject(subject) {
44+
throw NatsError.SubscriptionError.invalidSubject
45+
}
46+
if let queue, !NatsSubscription.validQueue(queue) {
47+
throw NatsError.SubscriptionError.invalidQueue
48+
}
4349
self.sid = sid
4450
self.subject = subject
51+
self.queue = queue
4552
self.capacity = capacity
4653
self.buffer = []
4754
self.conn = conn
@@ -150,4 +157,26 @@ public class NatsSubscription: AsyncSequence {
150157
}
151158
return try await self.conn.unsubscribe(sub: self, max: after)
152159
}
160+
161+
// validateSubject will do a basic subject validation.
162+
// Spaces are not allowed and all tokens should be > 0 in length.
163+
private static func validSubject(_ subj: String) -> Bool {
164+
let whitespaceCharacterSet = CharacterSet.whitespacesAndNewlines
165+
if subj.rangeOfCharacter(from: whitespaceCharacterSet) != nil {
166+
return false
167+
}
168+
let tokens = subj.split(separator: ".")
169+
for token in tokens {
170+
if token.isEmpty {
171+
return false
172+
}
173+
}
174+
return true
175+
}
176+
177+
// validQueue will check a queue name for whitespaces.
178+
private static func validQueue(_ queue: String) -> Bool {
179+
let whitespaceCharacterSet = CharacterSet.whitespacesAndNewlines
180+
return queue.rangeOfCharacter(from: whitespaceCharacterSet) == nil
181+
}
153182
}

Tests/NatsTests/Integration/ConnectionTests.swift

+47
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,53 @@ class CoreNatsTests: XCTestCase {
314314
XCTAssertEqual(message?.payload, "msg".data(using: .utf8)!)
315315
}
316316

317+
func testQueueGroupSubscribe() async throws {
318+
natsServer.start()
319+
logger.logLevel = .debug
320+
let client = NatsClientOptions().url(URL(string: natsServer.clientURL)!).build()
321+
try await client.connect()
322+
323+
let sub1 = try await client.subscribe(subject: "test", queue: "queueGroup")
324+
let sub2 = try await client.subscribe(subject: "test", queue: "queueGroup")
325+
326+
try await client.publish("msg".data(using: .utf8)!, subject: "test")
327+
328+
try await withThrowingTaskGroup(of: NatsMessage?.self) { group in
329+
group.addTask {try await sub1.makeAsyncIterator().next()}
330+
group.addTask {try await sub2.makeAsyncIterator().next()}
331+
group.addTask {
332+
try await Task.sleep(nanoseconds: UInt64(1_000_000_000))
333+
return nil
334+
}
335+
336+
var msgReceived = false
337+
var timeoutReceived = false
338+
for try await result in group {
339+
if let _ = result {
340+
if msgReceived == true {
341+
XCTFail("received 2 messages")
342+
return
343+
}
344+
msgReceived = true
345+
} else {
346+
if !msgReceived {
347+
XCTFail("timeout received before getting any messages")
348+
return
349+
}
350+
timeoutReceived = true
351+
}
352+
if msgReceived && timeoutReceived {
353+
break
354+
}
355+
}
356+
group.cancelAll()
357+
try await sub1.unsubscribe()
358+
try await sub2.unsubscribe()
359+
return
360+
}
361+
}
362+
363+
317364
func testUnsubscribe() async throws {
318365
natsServer.start()
319366
logger.logLevel = .debug

0 commit comments

Comments
 (0)