diff --git a/Package.swift b/Package.swift index 10aada2..b1ca68e 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,13 @@ let package = Package( .product(name: "Logging", package: "swift-log"), .product(name: "NIOFoundationCompat", package: "swift-nio") ]), - .testTarget(name: "NatsSwiftTests", dependencies: ["NatsSwift"]), + .testTarget( + name: "NatsSwiftTests", + dependencies: ["NatsSwift"], + resources: [ + .process("Integration/Resources") + ] + ), .executableTarget(name: "Benchmark", dependencies: ["NatsSwift"]), .executableTarget(name: "BenchmarkPubSub", dependencies: ["NatsSwift"]), diff --git a/Sources/NatsSwift/NatsClient/NatsClient.swift b/Sources/NatsSwift/NatsClient/NatsClient.swift index 11e288d..b91e6bd 100755 --- a/Sources/NatsSwift/NatsClient/NatsClient.swift +++ b/Sources/NatsSwift/NatsClient/NatsClient.swift @@ -19,13 +19,41 @@ public enum NatsState { case Disconnected } +public struct Auth { + var user: String? + var password: String? + var token: String? + + init(user: String, password: String) { + self.user = user + self.password = password + } + init(token: String) { + self.token = token + } +} + public class Client { + var urls: [URL] = [] + var pingInteval: TimeInterval = 1.0 + var reconnectWait: TimeInterval = 2.0 + var maxReconnects: Int? = nil + var auth: Auth? = nil + internal let allocator = ByteBufferAllocator() internal var buffer: ByteBuffer internal var connectionHandler: ConnectionHandler? internal init() { self.buffer = allocator.buffer(capacity: 1024) + self.connectionHandler = ConnectionHandler( + inputBuffer: buffer, + urls: urls, + reconnectWait: reconnectWait, + maxReconnects: maxReconnects, + pingInterval: pingInteval, + auth: auth + ) } } @@ -62,6 +90,6 @@ extension Client { throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"]) } return try await connectionHandler.subscribe(subject) - + } } diff --git a/Sources/NatsSwift/NatsClient/NatsClientOptions.swift b/Sources/NatsSwift/NatsClient/NatsClientOptions.swift index 532d93e..5ccf284 100644 --- a/Sources/NatsSwift/NatsClient/NatsClientOptions.swift +++ b/Sources/NatsSwift/NatsClient/NatsClientOptions.swift @@ -13,6 +13,7 @@ public class ClientOptions { private var pingInterval: TimeInterval = 60.0 private var reconnectWait: TimeInterval = 2.0 private var maxReconnects: Int = 60 + private var auth: Auth? = nil public init() {} @@ -20,7 +21,7 @@ public class ClientOptions { self.urls = urls return self } - + public func url(_ url: URL) -> ClientOptions { self.urls = [url] return self @@ -41,6 +42,25 @@ public class ClientOptions { return self } + public func username_and_password(_ username: String, _ password: String) -> ClientOptions { + if self.auth == nil { + self.auth = Auth(user: username, password: password) + } else { + self.auth?.user = username + self.auth?.password = password + } + return self + } + + public func token(_ token: String) -> ClientOptions { + if self.auth == nil { + self.auth = Auth(token: token) + } else { + self.auth?.token = token + } + return self + } + public func build() -> Client { let client = Client() client.connectionHandler = ConnectionHandler( @@ -48,9 +68,10 @@ public class ClientOptions { urls: urls, reconnectWait: reconnectWait, maxReconnects: maxReconnects, - pingInterval: pingInterval + pingInterval: pingInterval, + auth: auth ) - + return client } } diff --git a/Sources/NatsSwift/NatsConnection.swift b/Sources/NatsSwift/NatsConnection.swift index 5fefbcf..dbb82c8 100644 --- a/Sources/NatsSwift/NatsConnection.swift +++ b/Sources/NatsSwift/NatsConnection.swift @@ -19,12 +19,13 @@ class ConnectionHandler: ChannelInboundHandler { internal let reconnectWait: UInt64 internal let maxReconnects: Int? internal let pingInterval: TimeInterval - + typealias InboundIn = ByteBuffer internal var state: NatsState = .Pending internal var subscriptions: [ UInt64: Subscription ] internal var subscriptionCounter = SubscriptionCounter() internal var serverInfo: ServerInfo? + internal var auth: Auth? private var parseRemainder: Data? @@ -37,7 +38,7 @@ class ConnectionHandler: ChannelInboundHandler { // TODO(pp): errors in parser should trigger context.fireErrorCaught() which invokes errorCaught() and invokes reconnect func channelReadComplete(context: ChannelHandlerContext) { var inputChunk = Data(buffer: inputBuffer) - + if let remainder = self.parseRemainder { inputChunk.prepend(remainder) } @@ -99,7 +100,7 @@ class ConnectionHandler: ChannelInboundHandler { } inputBuffer.clear() } - init(inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?, pingInterval: TimeInterval) { + init(inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?, pingInterval: TimeInterval, auth: Auth?) { self.inputBuffer = self.allocator.buffer(capacity: 1024) self.urls = urls self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) @@ -107,8 +108,10 @@ class ConnectionHandler: ChannelInboundHandler { self.subscriptions = [UInt64:Subscription]() self.reconnectWait = UInt64(reconnectWait * 1_000_000_000) self.maxReconnects = maxReconnects + self.auth = auth self.pingInterval = pingInterval } + internal var group: MultiThreadedEventLoopGroup internal var channel: Channel? @@ -146,7 +149,8 @@ class ConnectionHandler: ChannelInboundHandler { // Wait for the first message after sending the connect request } self.serverInfo = info - let connect = ConnectInfo(verbose: false, pedantic: false, userJwt: nil, nkey: "", signature: nil, name: "", echo: true, lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false, user: "", pass: "", authToken: "", headers: true, noResponders: true) + // TODO(jrm): Add rest of auth here. + let connect = ConnectInfo(verbose: false, pedantic: false, userJwt: nil, nkey: "", signature: nil, name: "", echo: true, lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false, user: self.auth?.user ?? "", pass: self.auth?.password ?? "", authToken: self.auth?.token ?? "", headers: true, noResponders: true) try await withCheckedThrowingContinuation { continuation in self.connectionEstablishedContinuation = continuation @@ -163,7 +167,7 @@ class ConnectionHandler: ChannelInboundHandler { self.state = .Connected logger.debug("connection established") } - + func channelActive(context: ChannelHandlerContext) { logger.debug("TCP channel active") @@ -175,19 +179,19 @@ class ConnectionHandler: ChannelInboundHandler { handleDisconnect() } - + func errorCaught(context: ChannelHandlerContext, error: Error) { // TODO(pp): implement Close() on the connection and call it here logger.debug("Encountered error on the channel: \(error)") self.state = .Disconnected handleReconnect() } - + func handleDisconnect() { self.state = .Disconnected handleReconnect() } - + func handleReconnect() { Task { var attempts = 0 diff --git a/Tests/NatsSwiftTests/Integration/ConnectionTests.swift b/Tests/NatsSwiftTests/Integration/ConnectionTests.swift old mode 100644 new mode 100755 index 150bf3a..0ce39a9 --- a/Tests/NatsSwiftTests/Integration/ConnectionTests.swift +++ b/Tests/NatsSwiftTests/Integration/ConnectionTests.swift @@ -14,7 +14,10 @@ class CoreNatsTests: XCTestCase { ("testPublish", testPublish), ("testPublishWithReply", testPublishWithReply), ("testSubscribe", testSubscribe), - ("testConnect", testConnect) + ("testConnect", testConnect), + ("testReconnect", testReconnect), + ("testUsernameAndPassword", testUsernameAndPassword), + ("testTokenAuth", testTokenAuth) ] var natsServer = NatsServer() @@ -137,7 +140,7 @@ class CoreNatsTests: XCTestCase { try client.publish(payload, subject: "foo") } } - + for await msg in sub { messagesReceived += 1 if messagesReceived == 10 { @@ -148,4 +151,78 @@ class CoreNatsTests: XCTestCase { // Check if the total number of messages received matches the number sent XCTAssertEqual(20, messagesReceived, "Mismatch in the number of messages sent and received") } + + func testUsernameAndPassword() async throws { + logger.logLevel = .debug + let currentFile = URL(fileURLWithPath: #file) + // Navigate up to the Tests directory + let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent() + // Construct the path to the resource + let resourceURL = testsDir + .appendingPathComponent("Integration/Resources/creds.conf", isDirectory: false) + natsServer.start(cfg: resourceURL.path) + let client = ClientOptions() + .url(URL(string:natsServer.clientURL)!) + .username_and_password("derek", "s3cr3t") + .maxReconnects(5) + .build() + try await client.connect() + try client.publish("msg".data(using: .utf8)!, subject: "test") + try await client.flush() + try await client.subscribe(to: "test") + XCTAssertNotNil(client, "Client should not be nil") + + + // Test if client with bad credentials throws an error + let bad_creds_client = ClientOptions() + .url(URL(string:natsServer.clientURL)!) + .username_and_password("derek", "badpassword") + .maxReconnects(5) + .build() + + do { + try await bad_creds_client.connect() + XCTFail("Should have thrown an error") + } catch { + XCTAssertNotNil(error, "Error should not be nil") + } + + } + + func testTokenAuth() async throws { + logger.logLevel = .debug + let currentFile = URL(fileURLWithPath: #file) + // Navigate up to the Tests directory + let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent() + // Construct the path to the resource + let resourceURL = testsDir + .appendingPathComponent("Integration/Resources/token.conf", isDirectory: false) + natsServer.start(cfg: resourceURL.path) + let client = ClientOptions() + .url(URL(string:natsServer.clientURL)!) + .token("s3cr3t") + .maxReconnects(5) + .build() + try await client.connect() + try client.publish("msg".data(using: .utf8)!, subject: "test") + try await client.flush() + try await client.subscribe(to: "test") + XCTAssertNotNil(client, "Client should not be nil") + + + // Test if client with bad credentials throws an error + let bad_creds_client = ClientOptions() + .url(URL(string:natsServer.clientURL)!) + .token("badtoken") + .maxReconnects(5) + .build() + + do { + try await bad_creds_client.connect() + XCTFail("Should have thrown an error") + } catch { + XCTAssertNotNil(error, "Error should not be nil") + } + + } } diff --git a/Tests/NatsSwiftTests/Integration/Resources/creds.conf b/Tests/NatsSwiftTests/Integration/Resources/creds.conf new file mode 100644 index 0000000..db158eb --- /dev/null +++ b/Tests/NatsSwiftTests/Integration/Resources/creds.conf @@ -0,0 +1,5 @@ +authorization { + user: derek + password: s3cr3t +} + diff --git a/Tests/NatsSwiftTests/Integration/Resources/token.conf b/Tests/NatsSwiftTests/Integration/Resources/token.conf new file mode 100644 index 0000000..8257cb9 --- /dev/null +++ b/Tests/NatsSwiftTests/Integration/Resources/token.conf @@ -0,0 +1,4 @@ +authorization { + token: s3cr3t +} +