Skip to content

Commit 7423ebf

Browse files
committed
add unit test cases for websocket client
1 parent 3cdcb3d commit 7423ebf

19 files changed

+1990
-1172
lines changed

AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient.swift

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,6 @@ import Amplify
1111
import Combine
1212
@_spi(AmplifySwift) import AWSPluginsCore
1313

14-
protocol AppSyncRequestInterceptor {
15-
func interceptRequest(event: AppSyncRealTimeRequest, url: URL) async -> AppSyncRealTimeRequest
16-
}
17-
18-
protocol AppSyncWebSocketClientProtocol {
19-
var isConnected: Bool { get async }
20-
var publisher: AnyPublisher<WebSocketEvent, Never> { get async }
21-
22-
func connect(
23-
autoConnectOnNetworkStatusChange: Bool,
24-
autoRetryOnConnectionFailure: Bool
25-
) async
26-
27-
func disconnect() async
28-
29-
func write(message: String) async throws
30-
}
31-
32-
extension WebSocketClient: AppSyncWebSocketClientProtocol { }
33-
3414
actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
3515

3616
static let jsonEncoder = JSONEncoder()
@@ -153,6 +133,8 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
153133
private func startSubscription(id: String, query: String) async throws -> AnyCancellable {
154134
log.debug("[AppSyncRealTimeClient] Starting subscription request \(id), query: \(query)")
155135

136+
// TODO: (5d) it seems the current implementation is no retry on request level
137+
// we just pass down the errors to subscribers to handle
156138
try await RetryWithJitter.execute { [weak self] in
157139
guard let self else { return }
158140
try await Self.sendRequestWithTimeout(
@@ -232,11 +214,13 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
232214
private func filterAppSyncSubscriptionEvent(
233215
with id: String
234216
) -> AnyPublisher<AppSyncSubscriptionEvent, Never> {
235-
subject.filter { $0.id == id }
217+
subject.filter { $0.id == id || $0.type == .connectionError }
236218
.map { response -> AppSyncSubscriptionEvent? in
237219
switch response.type {
238220
case .startAck: return .subscribed
239221
case .stopAck: return .unsubscribed
222+
case .connectionError:
223+
return .error(Self.decodeURLErrors(response.payload))
240224
case .error:
241225
return .error(Self.decodeGraphQLErrors(response.payload))
242226
case .data:
@@ -250,6 +234,20 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
250234
.eraseToAnyPublisher()
251235
}
252236

237+
private static func decodeURLErrors(_ data: JSONValue?) -> [Error] {
238+
guard let errors = data?.errors?.asArray else {
239+
return []
240+
}
241+
242+
return errors.flatMap { error -> [Error] in
243+
guard let code = error.errorCode?.intValue else {
244+
return []
245+
}
246+
let description = error.errorType?.stringValue ?? ""
247+
return [URLError(URLError.Code(rawValue: code), userInfo: ["description": description])]
248+
}
249+
}
250+
253251
private static func decodeGraphQLErrors(_ data: JSONValue?) -> [Error] {
254252
do {
255253
return try GraphQLErrorDecoder.decodeAppSyncErrors(data)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
9+
import Foundation
10+
11+
protocol AppSyncRequestInterceptor {
12+
func interceptRequest(event: AppSyncRealTimeRequest, url: URL) async -> AppSyncRealTimeRequest
13+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
9+
import Foundation
10+
import Combine
11+
@_spi(AmplifySwift) import AWSPluginsCore
12+
13+
protocol AppSyncWebSocketClientProtocol {
14+
var isConnected: Bool { get async }
15+
var publisher: AnyPublisher<WebSocketEvent, Never> { get async }
16+
17+
func connect(
18+
autoConnectOnNetworkStatusChange: Bool,
19+
autoRetryOnConnectionFailure: Bool
20+
) async
21+
22+
func disconnect() async
23+
24+
func write(message: String) async throws
25+
}
26+
27+
extension WebSocketClient: AppSyncWebSocketClientProtocol { }
28+

AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public class AWSGraphQLSubscriptionTaskRunner<R: Decodable>: InternalTaskRunner,
4545
}
4646

4747
public func cancel() {
48+
self.send(GraphQLSubscriptionEvent<R>.connection(.disconnected))
4849
Task { [weak self] in
4950
guard let self else {
5051
return
@@ -54,8 +55,6 @@ public class AWSGraphQLSubscriptionTaskRunner<R: Decodable>: InternalTaskRunner,
5455
}
5556
do {
5657
try await appSyncClient.unsubscribe(id: self.subscriptionId)
57-
let subscriptionEvent = GraphQLSubscriptionEvent<R>.connection(.disconnected)
58-
self.send(subscriptionEvent)
5958
} catch {
6059
print("Failed to unsubscribe \(self.subscriptionId)")
6160
}
@@ -224,7 +223,8 @@ final public class AWSGraphQLSubscriptionOperation<R: Decodable>: GraphQLSubscri
224223
}
225224

226225
override public func cancel() {
227-
let superCancel = super.cancel
226+
super.cancel()
227+
228228
Task { [weak self] in
229229
guard let self else {
230230
return
@@ -236,12 +236,6 @@ final public class AWSGraphQLSubscriptionOperation<R: Decodable>: GraphQLSubscri
236236

237237
do {
238238
try await appSyncRealTimeClient.unsubscribe(id: subscriptionId)
239-
let subscriptionEvent = GraphQLSubscriptionEvent<R>.connection(.disconnected)
240-
dispatchInProcess(data: subscriptionEvent)
241-
242-
243-
dispatch(result: .successfulVoid)
244-
superCancel()
245239
finish()
246240
} catch {
247241
print("Failed to unsubscribe \(error)")
Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
////
2-
//// Copyright Amazon.com Inc. or its affiliates.
3-
//// All Rights Reserved.
4-
////
5-
//// SPDX-License-Identifier: Apache-2.0
6-
////
71
//
8-
//import XCTest
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
94
//
10-
//// swiftlint:disable:next type_name
11-
//class AWSAPICategoryPluginURLSessionBehaviorDelegateTests: AWSAPICategoryPluginTestBase {
12-
// func testClassMustNotBeEmptyOrSwiftFormatWillCrash() {
13-
// // TODO implement code
14-
// }
15-
//}
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
import XCTest
9+
10+
// swiftlint:disable:next type_name
11+
class AWSAPICategoryPluginURLSessionBehaviorDelegateTests: AWSAPICategoryPluginTestBase {
12+
func testClassMustNotBeEmptyOrSwiftFormatWillCrash() {
13+
// TODO implement code
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
9+
import XCTest
10+
import Amplify
11+
@testable import AWSAPIPlugin
12+
13+
class APIKeyAuthInterceptorTests: XCTestCase {
14+
15+
func testInterceptConnection_addApiKeySignatureInURLQuery() async {
16+
let apiKey = UUID().uuidString
17+
let interceptor = APIKeyAuthInterceptor(apiKey: apiKey)
18+
let resultUrl = await interceptor.interceptConnection(url: URL(string: "https://example.com")!)
19+
guard let components = URLComponents(url: resultUrl, resolvingAgainstBaseURL: false) else {
20+
XCTFail("Failed to decode decorated URL")
21+
return
22+
}
23+
24+
let header = components.queryItems?.first { $0.name == "header" }
25+
XCTAssertNotNil(header?.value)
26+
let headerData = try! header?.value!.base64DecodedString().data(using: .utf8)
27+
let decodedHeader = try! JSONDecoder().decode(JSONValue.self, from: headerData!)
28+
XCTAssertEqual(decodedHeader["x-api-key"]?.stringValue, apiKey)
29+
}
30+
31+
func testInterceptRequest_appendAuthInfoInPayload() async {
32+
let apiKey = UUID().uuidString
33+
let interceptor = APIKeyAuthInterceptor(apiKey: apiKey)
34+
let decoratedRequest = await interceptor.interceptRequest(
35+
event: AppSyncRealTimeRequest.start(.init(
36+
id: UUID().uuidString,
37+
data: "",
38+
auth: nil
39+
)),
40+
url: URL(string: "https://example.appsync-realtime-api.amazonaws.com")!
41+
)
42+
guard case let .start(request) = decoratedRequest else {
43+
XCTFail("Request should be a start request")
44+
return
45+
}
46+
47+
XCTAssertNotNil(request.auth)
48+
guard case let .apiKey(apiKeyInfo) = request.auth! else {
49+
XCTFail("Auth should be api key")
50+
return
51+
}
52+
53+
XCTAssertEqual(apiKeyInfo.apiKey, apiKey)
54+
XCTAssertEqual(apiKeyInfo.host, "example.appsync-api.amazonaws.com")
55+
}
56+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//
2+
// Copyright Amazon.com Inc. or its affiliates.
3+
// All Rights Reserved.
4+
//
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
8+
9+
import XCTest
10+
import Amplify
11+
@testable import AWSAPIPlugin
12+
@testable @_spi(AmplifySwift) import AWSPluginsCore
13+
14+
class CognitoAuthInterceptorTests: XCTestCase {
15+
16+
func testInterceptConnection() async {
17+
let authTokenProvider = MockAuthTokenProvider()
18+
let interceptor = CognitoAuthInterceptor(authTokenProvider: authTokenProvider)
19+
20+
let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!)
21+
guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else {
22+
XCTFail("Failed to get url components from decorated URL")
23+
return
24+
}
25+
26+
guard let queryHeaderString =
27+
try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString()
28+
else {
29+
XCTFail("Failed to extract header field from query string")
30+
return
31+
}
32+
33+
guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!)
34+
else {
35+
XCTFail("Failed to decode query header to json object")
36+
return
37+
}
38+
XCTAssertEqual(authTokenProvider.authToken, queryHeader.Authorization?.stringValue)
39+
XCTAssertEqual("example.com", queryHeader.host?.stringValue)
40+
}
41+
42+
func testInterceptRequest() async {
43+
let authTokenProvider = MockAuthTokenProvider()
44+
let interceptor = CognitoAuthInterceptor(authTokenProvider: authTokenProvider)
45+
let decoratedRequest = await interceptor.interceptRequest(
46+
event: .start(.init(id: UUID().uuidString, data: UUID().uuidString, auth: nil)),
47+
url: URL(string: "https://example.com")!
48+
)
49+
50+
guard case let .start(decoratedAuth) = decoratedRequest else {
51+
XCTFail("Failed to extract decoratedAuth info")
52+
return
53+
}
54+
55+
guard case let .some(.cognito(authInfo)) = decoratedAuth.auth else {
56+
XCTFail("Failed to extract authInfo from decoratedAuth")
57+
return
58+
}
59+
60+
XCTAssertEqual(authTokenProvider.authToken, authInfo.authToken)
61+
XCTAssertEqual("example.com", authInfo.host)
62+
}
63+
}
64+
65+
fileprivate class MockAuthTokenProvider: AmplifyAuthTokenProvider {
66+
let authToken = UUID().uuidString
67+
func getLatestAuthToken() async throws -> String {
68+
return authToken
69+
}
70+
}

AmplifyPlugins/API/Tests/AWSAPIPluginTests/Mocks/MockSubscription.swift

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ struct MockSubscriptionConnectionFactory: AppSyncRealTimeClientFactoryProtocol {
2121
AWSAuthServiceBehavior,
2222
AWSAuthorizationType?,
2323
APIAuthProviderFactory
24-
) async throws -> AppSyncRealTimeClient
24+
) async throws -> AppSyncRealTimeClientProtocol
2525

2626
let onGetOrCreateConnection: OnGetOrCreateConnection
2727

@@ -40,35 +40,46 @@ struct MockSubscriptionConnectionFactory: AppSyncRealTimeClientFactoryProtocol {
4040
}
4141
}
4242

43-
struct MockAppSyncRealTimeClient {
44-
// typealias OnSubscribe = (
45-
// String,
46-
// [String: Any?]?,
47-
// @escaping SubscriptionEventHandler
48-
// ) -> SubscriptionItem
49-
//
50-
// typealias OnUnsubscribe = (SubscriptionItem) -> Void
51-
//
52-
// let onSubscribe: OnSubscribe
53-
// let onUnsubscribe: OnUnsubscribe
54-
//
55-
// init(onSubscribe: @escaping OnSubscribe, onUnsubscribe: @escaping OnUnsubscribe) {
56-
// self.onSubscribe = onSubscribe
57-
// self.onUnsubscribe = onUnsubscribe
58-
// }
59-
//
60-
// func subscribe(
61-
// requestString: String,
62-
// variables: [String: Any?]?,
63-
// eventHandler: @escaping SubscriptionEventHandler
64-
// ) -> SubscriptionItem {
65-
// onSubscribe(requestString, variables, eventHandler)
66-
// }
67-
//
68-
// func unsubscribe(item: SubscriptionItem) {
69-
// onUnsubscribe(item)
70-
// }
43+
struct MockAppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
44+
private let subject = PassthroughSubject<AppSyncSubscriptionEvent, Never>()
7145

46+
func subscribe(id: String, query: String) async throws -> AnyPublisher<AppSyncSubscriptionEvent, Never> {
47+
defer {
48+
49+
Task {
50+
try await Task.sleep(seconds: 0.25)
51+
subject.send(.subscribing)
52+
try await Task.sleep(seconds: 0.45)
53+
subject.send(.subscribed)
54+
}
55+
}
56+
return subject.eraseToAnyPublisher()
57+
}
58+
59+
func unsubscribe(id: String) async throws {
60+
try await Task.sleep(seconds: 0.45)
61+
subject.send(.unsubscribed)
62+
}
63+
64+
func connect() async throws { }
65+
66+
func disconnect() async { }
67+
68+
func triggerEvent(_ event: AppSyncSubscriptionEvent) {
69+
subject.send(event)
70+
}
71+
72+
static func waitForSubscirbing() async throws {
73+
try await Task.sleep(seconds: 0.3)
74+
}
75+
76+
static func waitForSubscirbed() async throws {
77+
try await Task.sleep(seconds: 0.5)
78+
}
79+
80+
static func waitForUnsubscirbed() async throws {
81+
try await Task.sleep(seconds: 0.5)
82+
}
7283
}
7384

7485
class MockAppSyncRequestInterceptor: AppSyncRequestInterceptor {

0 commit comments

Comments
 (0)