Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add user and password auth #24

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ let package = Package(
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOFoundationCompat", package: "swift-nio")
]),
.testTarget(name: "NatsSwiftTests", dependencies: ["NatsSwift"]),
.testTarget(
name: "NatsSwiftTests",
dependencies: ["NatsSwift"],
resources: [
.process("Integration/Resources")
]
),

.executableTarget(name: "Benchmark", dependencies: ["NatsSwift"]),
.executableTarget(name: "BenchmarkPubSub", dependencies: ["NatsSwift"]),
Expand Down
30 changes: 29 additions & 1 deletion Sources/NatsSwift/NatsClient/NatsClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,41 @@ public enum NatsState {
case Disconnected
}

public struct Auth {
var user: String?
var password: String?
var token: String?

init(user: String, password: String) {
self.user = user
self.password = password
}
init(token: String) {
self.token = token
}
}

public class Client {
var urls: [URL] = []
var pingInteval: TimeInterval = 1.0
var reconnectWait: TimeInterval = 2.0
var maxReconnects: Int? = nil
var auth: Auth? = nil

internal let allocator = ByteBufferAllocator()
internal var buffer: ByteBuffer
internal var connectionHandler: ConnectionHandler?

internal init() {
self.buffer = allocator.buffer(capacity: 1024)
self.connectionHandler = ConnectionHandler(
inputBuffer: buffer,
urls: urls,
reconnectWait: reconnectWait,
maxReconnects: maxReconnects,
pingInterval: pingInteval,
auth: auth
)
}
}

Expand Down Expand Up @@ -62,6 +90,6 @@ extension Client {
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
}
return try await connectionHandler.subscribe(subject)

}
}
27 changes: 24 additions & 3 deletions Sources/NatsSwift/NatsClient/NatsClientOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ public class ClientOptions {
private var pingInterval: TimeInterval = 60.0
private var reconnectWait: TimeInterval = 2.0
private var maxReconnects: Int = 60
private var auth: Auth? = nil

public init() {}

public func urls(_ urls: [URL]) -> ClientOptions {
self.urls = urls
return self
}

public func url(_ url: URL) -> ClientOptions {
self.urls = [url]
return self
Expand All @@ -41,16 +42,36 @@ public class ClientOptions {
return self
}

public func username_and_password(_ username: String, _ password: String) -> ClientOptions {
if self.auth == nil {
self.auth = Auth(user: username, password: password)
} else {
self.auth?.user = username
self.auth?.password = password
}
return self
}

public func token(_ token: String) -> ClientOptions {
if self.auth == nil {
self.auth = Auth(token: token)
} else {
self.auth?.token = token
}
return self
}

public func build() -> Client {
let client = Client()
client.connectionHandler = ConnectionHandler(
inputBuffer: client.buffer,
urls: urls,
reconnectWait: reconnectWait,
maxReconnects: maxReconnects,
pingInterval: pingInterval
pingInterval: pingInterval,
auth: auth
)

return client
}
}
20 changes: 12 additions & 8 deletions Sources/NatsSwift/NatsConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class ConnectionHandler: ChannelInboundHandler {
internal let reconnectWait: UInt64
internal let maxReconnects: Int?
internal let pingInterval: TimeInterval

typealias InboundIn = ByteBuffer
internal var state: NatsState = .Pending
internal var subscriptions: [ UInt64: Subscription ]
internal var subscriptionCounter = SubscriptionCounter()
internal var serverInfo: ServerInfo?
internal var auth: Auth?
private var parseRemainder: Data?


Expand All @@ -37,7 +38,7 @@ class ConnectionHandler: ChannelInboundHandler {
// TODO(pp): errors in parser should trigger context.fireErrorCaught() which invokes errorCaught() and invokes reconnect
func channelReadComplete(context: ChannelHandlerContext) {
var inputChunk = Data(buffer: inputBuffer)

if let remainder = self.parseRemainder {
inputChunk.prepend(remainder)
}
Expand Down Expand Up @@ -99,16 +100,18 @@ class ConnectionHandler: ChannelInboundHandler {
}
inputBuffer.clear()
}
init(inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?, pingInterval: TimeInterval) {
init(inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?, pingInterval: TimeInterval, auth: Auth?) {
self.inputBuffer = self.allocator.buffer(capacity: 1024)
self.urls = urls
self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
self.inputBuffer = allocator.buffer(capacity: 1024)
self.subscriptions = [UInt64:Subscription]()
self.reconnectWait = UInt64(reconnectWait * 1_000_000_000)
self.maxReconnects = maxReconnects
self.auth = auth
self.pingInterval = pingInterval
}

internal var group: MultiThreadedEventLoopGroup
internal var channel: Channel?

Expand Down Expand Up @@ -146,7 +149,8 @@ class ConnectionHandler: ChannelInboundHandler {
// Wait for the first message after sending the connect request
}
self.serverInfo = info
let connect = ConnectInfo(verbose: false, pedantic: false, userJwt: nil, nkey: "", signature: nil, name: "", echo: true, lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false, user: "", pass: "", authToken: "", headers: true, noResponders: true)
// TODO(jrm): Add rest of auth here.
let connect = ConnectInfo(verbose: false, pedantic: false, userJwt: nil, nkey: "", signature: nil, name: "", echo: true, lang: self.lang, version: self.version, natsProtocol: .dynamic, tlsRequired: false, user: self.auth?.user ?? "", pass: self.auth?.password ?? "", authToken: self.auth?.token ?? "", headers: true, noResponders: true)

try await withCheckedThrowingContinuation { continuation in
self.connectionEstablishedContinuation = continuation
Expand All @@ -163,7 +167,7 @@ class ConnectionHandler: ChannelInboundHandler {
self.state = .Connected
logger.debug("connection established")
}

func channelActive(context: ChannelHandlerContext) {
logger.debug("TCP channel active")

Expand All @@ -175,19 +179,19 @@ class ConnectionHandler: ChannelInboundHandler {

handleDisconnect()
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
// TODO(pp): implement Close() on the connection and call it here
logger.debug("Encountered error on the channel: \(error)")
self.state = .Disconnected
handleReconnect()
}

func handleDisconnect() {
self.state = .Disconnected
handleReconnect()
}

func handleReconnect() {
Task {
var attempts = 0
Expand Down
81 changes: 79 additions & 2 deletions Tests/NatsSwiftTests/Integration/ConnectionTests.swift
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ class CoreNatsTests: XCTestCase {
("testPublish", testPublish),
("testPublishWithReply", testPublishWithReply),
("testSubscribe", testSubscribe),
("testConnect", testConnect)
("testConnect", testConnect),
("testReconnect", testReconnect),
("testUsernameAndPassword", testUsernameAndPassword),
("testTokenAuth", testTokenAuth)
]
var natsServer = NatsServer()

Expand Down Expand Up @@ -137,7 +140,7 @@ class CoreNatsTests: XCTestCase {
try client.publish(payload, subject: "foo")
}
}

for await msg in sub {
messagesReceived += 1
if messagesReceived == 10 {
Expand All @@ -148,4 +151,78 @@ class CoreNatsTests: XCTestCase {
// Check if the total number of messages received matches the number sent
XCTAssertEqual(20, messagesReceived, "Mismatch in the number of messages sent and received")
}

func testUsernameAndPassword() async throws {
logger.logLevel = .debug
let currentFile = URL(fileURLWithPath: #file)
// Navigate up to the Tests directory
let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent()
// Construct the path to the resource
let resourceURL = testsDir
.appendingPathComponent("Integration/Resources/creds.conf", isDirectory: false)
natsServer.start(cfg: resourceURL.path)
let client = ClientOptions()
.url(URL(string:natsServer.clientURL)!)
.username_and_password("derek", "s3cr3t")
.maxReconnects(5)
.build()
try await client.connect()
try client.publish("msg".data(using: .utf8)!, subject: "test")
try await client.flush()
try await client.subscribe(to: "test")
XCTAssertNotNil(client, "Client should not be nil")


// Test if client with bad credentials throws an error
let bad_creds_client = ClientOptions()
.url(URL(string:natsServer.clientURL)!)
.username_and_password("derek", "badpassword")
.maxReconnects(5)
.build()

do {
try await bad_creds_client.connect()
XCTFail("Should have thrown an error")
} catch {
XCTAssertNotNil(error, "Error should not be nil")
}

}

func testTokenAuth() async throws {
logger.logLevel = .debug
let currentFile = URL(fileURLWithPath: #file)
// Navigate up to the Tests directory
let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent()
// Construct the path to the resource
let resourceURL = testsDir
.appendingPathComponent("Integration/Resources/token.conf", isDirectory: false)
natsServer.start(cfg: resourceURL.path)
let client = ClientOptions()
.url(URL(string:natsServer.clientURL)!)
.token("s3cr3t")
.maxReconnects(5)
.build()
try await client.connect()
try client.publish("msg".data(using: .utf8)!, subject: "test")
try await client.flush()
try await client.subscribe(to: "test")
XCTAssertNotNil(client, "Client should not be nil")


// Test if client with bad credentials throws an error
let bad_creds_client = ClientOptions()
.url(URL(string:natsServer.clientURL)!)
.token("badtoken")
.maxReconnects(5)
.build()

do {
try await bad_creds_client.connect()
XCTFail("Should have thrown an error")
} catch {
XCTAssertNotNil(error, "Error should not be nil")
}

}
}
5 changes: 5 additions & 0 deletions Tests/NatsSwiftTests/Integration/Resources/creds.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
authorization {
user: derek
password: s3cr3t
}

4 changes: 4 additions & 0 deletions Tests/NatsSwiftTests/Integration/Resources/token.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
authorization {
token: s3cr3t
}