From 7740b5104e4c9a42f66af96d73d9529680d698c1 Mon Sep 17 00:00:00 2001 From: Di Wu Date: Wed, 18 Sep 2024 11:45:55 -0700 Subject: [PATCH] fix(api): append auth info as head fields for appSync realtime handshake request --- .../AppSyncRealTimeRequestAuth.swift | 34 +++--------- .../APIKeyAuthInterceptor.swift | 9 +-- .../AuthTokenInterceptor.swift | 11 ++-- .../IAMAuthInterceptor.swift | 11 ++-- .../AppSyncRealTimeClientFactory.swift | 12 ++-- .../Utils/URLRequest+AppSyncAuth.swift | 17 ++++++ .../AppSyncRealTimeRequestAuthTests.swift | 55 ------------------- .../APIKeyAuthInterceptorTests.swift | 17 ++---- .../CognitoAuthInterceptorTests.swift | 48 +++------------- .../AppSyncRealTimeClientFactoryTests.swift | 30 +++++++++- .../WebSocket/WebSocketClient.swift | 14 +---- .../WebSocket/WebSocketInterceptor.swift | 14 +++++ 12 files changed, 104 insertions(+), 168 deletions(-) create mode 100644 AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift index 87e01b1842..f649c3c380 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift @@ -9,6 +9,9 @@ import Foundation public enum AppSyncRealTimeRequestAuth { + private static let jsonEncoder = JSONEncoder() + private static let jsonDecoder = JSONDecoder() + case authToken(AuthToken) case apiKey(ApiKey) case iam(IAM) @@ -31,33 +34,10 @@ public enum AppSyncRealTimeRequestAuth { let amzDate: String } - public struct URLQuery { - let header: AppSyncRealTimeRequestAuth - let payload: String - - init(header: AppSyncRealTimeRequestAuth, payload: String = "{}") { - self.header = header - self.payload = payload - } - - func withBaseURL(_ url: URL, encoder: JSONEncoder? = nil) -> URL { - let jsonEncoder: JSONEncoder = encoder ?? JSONEncoder() - guard let headerJsonData = try? jsonEncoder.encode(header) else { - return url - } - - guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) - else { - return url - } - - urlComponents.queryItems = [ - URLQueryItem(name: "header", value: headerJsonData.base64EncodedString()), - URLQueryItem(name: "payload", value: try? payload.base64EncodedString()) - ] - - return urlComponents.url ?? url - } + var authHeaders: [String: String] { + (try? Self.jsonEncoder.encode(self)).flatMap { + try? Self.jsonDecoder.decode([String: String].self, from: $0) + } ?? [:] } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift index f52ded490e..5906f7567e 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift @@ -21,11 +21,12 @@ class APIKeyAuthInterceptor { } extension APIKeyAuthInterceptor: WebSocketInterceptor { - func interceptConnection(url: URL) async -> URL { + + func interceptConnection(request: URLRequest) async -> URLRequest { + guard let url = request.url else { return request } + let authHeader = getAuthHeader(apiKey, AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!) - return AppSyncRealTimeRequestAuth.URLQuery( - header: .apiKey(authHeader) - ).withBaseURL(url) + return request.injectAppSyncAuthToRequestHeader(auth: .apiKey(authHeader)) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift index b0f19ffd78..95f96698b1 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift @@ -57,15 +57,16 @@ extension AuthTokenInterceptor: AppSyncRequestInterceptor { } extension AuthTokenInterceptor: WebSocketInterceptor { - func interceptConnection(url: URL) async -> URL { + func interceptConnection(request: URLRequest) async -> URLRequest { + guard let url = request.url else { return request } let authToken = await getAuthToken() - return AppSyncRealTimeRequestAuth.URLQuery( - header: .authToken(.init( + return request.injectAppSyncAuthToRequestHeader( + auth: .authToken(.init( host: AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!, authToken: authToken - )) - ).withBaseURL(url) + ) + )) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift index cd023676c7..511f7925f6 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift @@ -88,15 +88,14 @@ class IAMAuthInterceptor { } extension IAMAuthInterceptor: WebSocketInterceptor { - func interceptConnection(url: URL) async -> URL { + + func interceptConnection(request: URLRequest) async -> URLRequest { + guard let url = request.url else { return request } let connectUrl = AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).appendingPathComponent("connect") guard let authHeader = await getAuthHeader(connectUrl, with: "{}") else { - return connectUrl + return request } - - return AppSyncRealTimeRequestAuth.URLQuery( - header: .iam(authHeader) - ).withBaseURL(url) + return request.injectAppSyncAuthToRequestHeader(auth: .iam(authHeader)) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift index 57a3708e1e..802bb142eb 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift @@ -127,9 +127,9 @@ actor AppSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol { extension AppSyncRealTimeClientFactory { /** - Converting appsync api url to realtime api url - 1. api.example.com/graphql -> api.example.com/graphql/realtime - 2. abc.appsync-api.us-east-1.amazonaws.com/graphql -> abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql + Converting appsync api url to realtime api url, realtime endpoint has scheme 'wss' + 1. api.example.com/graphql -> wss://api.example.com/graphql/realtime + 2. abc.appsync-api.us-east-1.amazonaws.com/graphql -> wss://abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql */ static func appSyncRealTimeEndpoint(_ url: URL) -> URL { guard let host = url.host else { @@ -145,6 +145,7 @@ extension AppSyncRealTimeClientFactory { } urlComponents.host = host.replacingOccurrences(of: "appsync-api", with: "appsync-realtime-api") + urlComponents.scheme = "wss" guard let realTimeUrl = urlComponents.url else { return url } @@ -153,9 +154,9 @@ extension AppSyncRealTimeClientFactory { } /** - Converting appsync realtime api url to api url + Converting appsync realtime api url to api url, api endpoint has scheme 'https' 1. api.example.com/graphql/realtime -> api.example.com/graphql - 2. abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql -> abc.appsync-api.us-east-1.amazonaws.com/graphql + 2. abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql -> https://abc.appsync-api.us-east-1.amazonaws.com/graphql */ static func appSyncApiEndpoint(_ url: URL) -> URL { guard let host = url.host else { @@ -174,6 +175,7 @@ extension AppSyncRealTimeClientFactory { } urlComponents.host = host.replacingOccurrences(of: "appsync-realtime-api", with: "appsync-api") + urlComponents.scheme = "https" guard let apiUrl = urlComponents.url else { return url } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift new file mode 100644 index 0000000000..c186934f8d --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift @@ -0,0 +1,17 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +extension URLRequest { + func injectAppSyncAuthToRequestHeader(auth: AppSyncRealTimeRequestAuth) -> URLRequest { + var requstCopy = self + auth.authHeaders.forEach { requstCopy.setValue($0.value, forHTTPHeaderField: $0.key) } + return requstCopy + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift index 6ab7af0692..e9f4061431 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift @@ -147,61 +147,6 @@ class AppSyncRealTimeRequestAuthTests: XCTestCase { """.shrink()) } - func testAppSyncRealTimeRequestAuth_URLQueryWithCognitoAuthHeader() { - let expectedURL = """ - https://example.com?\ - header=eyJBdXRob3JpemF0aW9uIjoiNDk4NTljN2MtNzQwNS00ZDU4LWFmZjctNTJiZ\ - TRiNDczNTU3IiwiaG9zdCI6ImV4YW1wbGUuY29tIn0%3D\ - &payload=e30%3D - """ - let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( - header: .authToken(.init( - host: "example.com", - authToken: "49859c7c-7405-4d58-aff7-52be4b473557" - )) - ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) - XCTAssertEqual(encodedURL.absoluteString, expectedURL) - } - - func testAppSyncRealTimeRequestAuth_URLQueryWithApiKeyAuthHeader() { - let expectedURL = """ - https://example.com?\ - header=eyJob3N0IjoiZXhhbXBsZS5jb20iLCJ4LWFtei1kYXRlIjoiOWUwZTJkZjktMmVlNy00NjU5L\ - TgzNjItMWM4ODFlMTE4YzlmIiwieC1hcGkta2V5IjoiNjVlMmZhY2EtOGUxZS00ZDM3LThkYzctNjQ0N\ - 2Q5Njk4MjQ3In0%3D\ - &payload=e30%3D - """ - let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( - header: .apiKey(.init( - host: "example.com", - apiKey: "65e2faca-8e1e-4d37-8dc7-6447d9698247", - amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f" - )) - ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) - XCTAssertEqual(encodedURL.absoluteString, expectedURL) - } - - func testAppSyncRealTimeRequestAuth_URLQueryWithIAMAuthHeader() { - - let expectedURL = """ - https://example.com?\ - header=eyJhY2NlcHQiOiJhcHBsaWNhdGlvblwvanNvbiwgdGV4dFwvamF2YXNjcmlwdCIsIkF1dGhvcml6YXR\ - pb24iOiJjOWRhZDg5Ny05MGQxLTRhNGMtYTVjOS0yYjM2YTI0NzczNWYiLCJjb250ZW50LWVuY29kaW5nIjoiY\ - W16LTEuMCIsImNvbnRlbnQtdHlwZSI6ImFwcGxpY2F0aW9uXC9qc29uOyBjaGFyc2V0PVVURi04IiwiaG9zdCI\ - 6ImV4YW1wbGUuY29tIiwieC1hbXotZGF0ZSI6IjllMGUyZGY5LTJlZTctNDY1OS04MzYyLTFjODgxZTExOGM5Z\ - iIsIlgtQW16LVNlY3VyaXR5LVRva2VuIjoiZTdlNjI2OWUtZmRhMS00ZGUwLThiZGItYmFhN2I2ZGQwYTBkIn0%3D\ - &payload=e30%3D - """ - let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( - header: .iam(.init( - host: "example.com", - authToken: "c9dad897-90d1-4a4c-a5c9-2b36a247735f", - securityToken: "e7e6269e-fda1-4de0-8bdb-baa7b6dd0a0d", - amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f")) - ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) - XCTAssertEqual(encodedURL.absoluteString, expectedURL) - } - private func toJson(_ value: Encodable) -> String? { return try? String(data: jsonEncoder.encode(value), encoding: .utf8) } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift index 8c89c0a53a..7c8ebff620 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift @@ -12,20 +12,13 @@ import Amplify class APIKeyAuthInterceptorTests: XCTestCase { - func testInterceptConnection_addApiKeySignatureInURLQuery() async { + func testInterceptConnection_addApiKeyInRequestHeader() async { let apiKey = UUID().uuidString let interceptor = APIKeyAuthInterceptor(apiKey: apiKey) - let resultUrl = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) - guard let components = URLComponents(url: resultUrl, resolvingAgainstBaseURL: false) else { - XCTFail("Failed to decode decorated URL") - return - } - - let header = components.queryItems?.first { $0.name == "header" } - XCTAssertNotNil(header?.value) - let headerData = try! header?.value!.base64DecodedString().data(using: .utf8) - let decodedHeader = try! JSONDecoder().decode(JSONValue.self, from: headerData!) - XCTAssertEqual(decodedHeader["x-api-key"]?.stringValue, apiKey) + let resultUrlRequest = await interceptor.interceptConnection(request: URLRequest(url: URL(string: "https://example.com")!)) + + let header = resultUrlRequest.value(forHTTPHeaderField: "x-api-key") + XCTAssertEqual(header, apiKey) } func testInterceptRequest_appendAuthInfoInPayload() async { diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift index 4127f018fd..d0383bff21 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift @@ -13,56 +13,24 @@ import Amplify class CognitoAuthInterceptorTests: XCTestCase { - func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeaderToQuery() async { + func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeader() async { let authTokenProvider = MockAuthTokenProvider() let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) - let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) - guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else { - XCTFail("Failed to get url components from decorated URL") - return - } + let decoratedURLRequest = await interceptor.interceptConnection(request: URLRequest(url:URL(string: "https://example.com")!)) - guard let queryHeaderString = - try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString() - else { - XCTFail("Failed to extract header field from query string") - return - } - - guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!) - else { - XCTFail("Failed to decode query header to json object") - return - } - XCTAssertEqual(authTokenProvider.authToken, queryHeader.Authorization?.stringValue) - XCTAssertEqual("example.com", queryHeader.host?.stringValue) + XCTAssertEqual(authTokenProvider.authToken, decoratedURLRequest.value(forHTTPHeaderField: "Authorization")) + XCTAssertEqual("example.com", decoratedURLRequest.value(forHTTPHeaderField: "host")) } - func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeaderToQuery() async { + func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeader() async { let authTokenProvider = MockAuthTokenProviderFailed() let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) - let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) - guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else { - XCTFail("Failed to get url components from decorated URL") - return - } + let decoratedURLRequest = await interceptor.interceptConnection(request: URLRequest(url:URL(string: "https://example.com")!)) - guard let queryHeaderString = - try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString() - else { - XCTFail("Failed to extract header field from query string") - return - } - - guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!) - else { - XCTFail("Failed to decode query header to json object") - return - } - XCTAssertEqual("", queryHeader.Authorization?.stringValue) - XCTAssertEqual("example.com", queryHeader.host?.stringValue) + XCTAssertEqual("", decoratedURLRequest.value(forHTTPHeaderField: "Authorization")) + XCTAssertEqual("example.com", decoratedURLRequest.value(forHTTPHeaderField: "host")) } func testInterceptRequest_withAuthTokenProvider_appendCorrectAuthInfoToPayload() async { diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift index 7156ac7678..15ca8c7858 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift @@ -15,15 +15,15 @@ class AppSyncRealTimeClientFactoryTests: XCTestCase { let appSyncEndpoint = URL(string: "https://abc.appsync-api.amazonaws.com/graphql")! XCTAssertEqual( AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint), - URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql") + URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql") ) } func testAppSyncRealTimeEndpoint_withAWSAppSyncRealTimeDomain_returnTheSameDomain() { - let appSyncEndpoint = URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql")! + let appSyncEndpoint = URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql")! XCTAssertEqual( AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint), - URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql") + URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql") ) } @@ -34,4 +34,28 @@ class AppSyncRealTimeClientFactoryTests: XCTestCase { URL(string: "https://test.example.com/graphql/realtime") ) } + + func testAppSyncApiEndpoint_withAWSAppSyncRealTimeDomain_returnCorrectApiDomain() { + let appSyncEndpoint = URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql")! + XCTAssertEqual( + AppSyncRealTimeClientFactory.appSyncApiEndpoint(appSyncEndpoint), + URL(string: "https://abc.appsync-api.amazonaws.com/graphql") + ) + } + + func testAppSyncApiEndpoint_withAWSAppSyncApiDomain_returnTheSameDomain() { + let appSyncEndpoint = URL(string: "https://abc.appsync-api.amazonaws.com/graphql")! + XCTAssertEqual( + AppSyncRealTimeClientFactory.appSyncApiEndpoint(appSyncEndpoint), + URL(string: "https://abc.appsync-api.amazonaws.com/graphql") + ) + } + + func testAppSyncApiEndpoint_withCustomDomain_returnCorrectRealtimePath() { + let appSyncEndpoint = URL(string: "https://test.example.com/graphql")! + XCTAssertEqual( + AppSyncRealTimeClientFactory.appSyncApiEndpoint(appSyncEndpoint), + URL(string: "https://test.example.com/graphql") + ) + } } diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift index cc1149ac27..e2e8c85503 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift @@ -72,7 +72,7 @@ public final actor WebSocketClient: NSObject { interceptor: WebSocketInterceptor? = nil, networkMonitor: WebSocketNetworkMonitorProtocol = AmplifyNetworkMonitor() ) { - self.url = Self.useWebSocketProtocolScheme(url: url) + self.url = url self.handshakeHttpHeaders = handshakeHttpHeaders self.interceptor = interceptor self.autoConnectOnNetworkStatusChange = false @@ -160,6 +160,8 @@ public final actor WebSocketClient: NSObject { var urlRequest = URLRequest(url: decoratedURL) self.handshakeHttpHeaders.forEach { urlRequest.setValue($0.value, forHTTPHeaderField: $0.key) } + urlRequest = await self.interceptor?.interceptConnection(request: urlRequest) ?? urlRequest + let urlSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) return urlSession.webSocketTask(with: urlRequest) } @@ -345,16 +347,6 @@ extension WebSocketClient { } } -extension WebSocketClient { - static func useWebSocketProtocolScheme(url: URL) -> URL { - guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) else { - return url - } - urlComponents.scheme = urlComponents.scheme == "http" ? "ws" : "wss" - return urlComponents.url ?? url - } -} - extension WebSocketClient: DefaultLogger { public static var log: Logger { Amplify.Logging.logger(forNamespace: String(describing: self)) diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift index a53ec3b950..351119ff03 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift @@ -11,4 +11,18 @@ import Foundation @_spi(WebSocket) public protocol WebSocketInterceptor { func interceptConnection(url: URL) async -> URL + + func interceptConnection(request: URLRequest) async -> URLRequest +} + +public extension WebSocketInterceptor { + + func interceptConnection(url: URL) async -> URL { + return url + } + + func interceptConnection(request: URLRequest) async -> URLRequest { + return request + } + }