Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into 5d/api-behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
5d committed Apr 24, 2024
2 parents ccbfd66 + adf5a2e commit 3079ec7
Show file tree
Hide file tree
Showing 15 changed files with 367 additions and 29 deletions.
51 changes: 40 additions & 11 deletions Amplify/Core/Support/TaskQueue.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,25 @@
import Foundation

/// A helper for executing asynchronous work serially.
public actor TaskQueue<Success> {
private var previousTask: Task<Success, Error>?
public class TaskQueue<Success> {
typealias Block = @Sendable () async -> Void
private var streamContinuation: AsyncStream<Block>.Continuation!

public init() {}
public init() {
let stream = AsyncStream<Block>.init { continuation in
streamContinuation = continuation
}

Task {
for await block in stream {
_ = await block()
}
}
}

deinit {
streamContinuation.finish()
}

/// Serializes asynchronous requests made from an async context
///
Expand All @@ -25,17 +40,31 @@ public actor TaskQueue<Success> {
/// TaskQueue serializes this work so that `doAsync1` is performed before `doAsync2`,
/// which is performed before `doAsync3`.
public func sync(block: @Sendable @escaping () async throws -> Success) async throws -> Success {
let currentTask: Task<Success, Error> = Task { [previousTask] in
_ = await previousTask?.result
return try await block()
try await withCheckedThrowingContinuation { continuation in
streamContinuation.yield {
do {
let value = try await block()
continuation.resume(returning: value)
} catch {
continuation.resume(throwing: error)
}
}
}
previousTask = currentTask
return try await currentTask.value
}

public nonisolated func async(block: @Sendable @escaping () async throws -> Success) rethrows {
Task {
try await sync(block: block)
public func async(block: @Sendable @escaping () async throws -> Success) {
streamContinuation.yield {
do {
_ = try await block()
} catch {
Self.log.warn("Failed to handle async task in TaskQueue<\(Success.self)> with error: \(error)")
}
}
}
}

extension TaskQueue {
public static var log: Logger {
Amplify.Logging.logger(forNamespace: String(describing: self))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol {
self.state.value == .connected
}

internal var numberOfSubscriptions: Int {
self.subscriptions.count
}

/**
Creates a new AppSyncRealTimeClient with endpoint, requestInterceptor and webSocketClient.
- Parameters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ public class AWSGraphQLSubscriptionTaskRunner<R: Decodable>: InternalTaskRunner,
self.apiAuthProviderFactory = apiAuthProviderFactory
}

/// When the top-level AmplifyThrowingSequence is canceled, this cancel method is invoked.
/// In this situation, we need to send the disconnected event because
/// the top-level AmplifyThrowingSequence is terminated immediately upon cancellation.
public func cancel() {
self.send(GraphQLSubscriptionEvent<R>.connection(.disconnected))
Task { [weak self] in
guard let self else {
return
}
Task {
guard let appSyncClient = self.appSyncClient else {
return
}
Expand Down Expand Up @@ -213,12 +213,7 @@ final public class AWSGraphQLSubscriptionOperation<R: Decodable>: GraphQLSubscri

override public func cancel() {
super.cancel()

Task { [weak self] in
guard let self else {
return
}

Task {
guard let appSyncRealTimeClient = self.appSyncRealTimeClient else {
return
}
Expand Down Expand Up @@ -378,6 +373,31 @@ fileprivate func toAPIError<R: Decodable>(_ errors: [Error], type: R.Type) -> AP
(hasAuthorizationError ? ": \(APIError.UnauthorizedMessageString)" : "")
}

#if swift(<5.8)
if let errors = errors.cast(to: AppSyncRealTimeRequest.Error.self) {
let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized})
return APIError.operationError(
errorDescription(hasAuthorizationError),
"",
errors.first
)
} else if let errors = errors.cast(to: GraphQLError.self) {
let hasAuthorizationError = errors.map(\.extensions)
.compactMap { $0.flatMap { $0["errorType"]?.stringValue } }
.contains(where: { AppSyncErrorType($0) == .unauthorized })
return APIError.operationError(
errorDescription(hasAuthorizationError),
"",
GraphQLResponseError<R>.error(errors)
)
} else {
return APIError.operationError(
errorDescription(),
"",
errors.first
)
}
#else
switch errors {
case let errors as [AppSyncRealTimeRequest.Error]:
let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized})
Expand All @@ -402,5 +422,5 @@ fileprivate func toAPIError<R: Decodable>(_ errors: [Error], type: R.Type) -> AP
errors.first
)
}

#endif
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//


import Foundation

@_spi(AmplifyAPI)
extension Array where Element == Error {
func cast<T>(to type: T.Type) -> [T]? {
self.reduce([]) { partialResult, ele in
if let partialResult, let ele = ele as? T {
return partialResult + [ele]
}
return nil
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,65 @@ class GraphQLModelBasedTests: XCTestCase {
await fulfillment(of: [progressInvoked], timeout: TestCommonConstants.networkTimeout)
}


/// Given: Several subscriptions with Amplify API plugin
/// When: Cancel subscriptions
/// Then: AppSync real time client automatically unsubscribe and remove the subscription
func testCancelledSubscription_automaticallyUnsubscribeAndRemoved() async throws {
let numberOfSubscription = 5
let allSubscribedExpectation = expectation(description: "All subscriptions are subscribed")
allSubscribedExpectation.expectedFulfillmentCount = numberOfSubscription

let subscriptions = (0..<5).map { _ in
Amplify.API.subscribe(request: .subscription(of: Comment.self, type: .onCreate))
}
subscriptions.forEach { subscription in
Task {
do {
for try await subscriptionEvent in subscription {
switch subscriptionEvent {
case .connection(let state):
switch state {
case .connecting:
break
case .connected:
allSubscribedExpectation.fulfill()
case .disconnected:
break
}
case .data(let result):
switch result {
case .success: break
case .failure(let error):
XCTFail("\(error)")
}
}
}
} catch {
XCTFail("Unexpected subscription failure")
}
}
}

await fulfillment(of: [allSubscribedExpectation], timeout: 3)
if let appSyncRealTimeClientFactory =
getUnderlyingAPIPlugin()?.appSyncRealTimeClientFactory as? AppSyncRealTimeClientFactory,
let appSyncRealTimeClient =
await appSyncRealTimeClientFactory.apiToClientCache.values.first as? AppSyncRealTimeClient
{
var appSyncSubscriptions = await appSyncRealTimeClient.numberOfSubscriptions
XCTAssertEqual(appSyncSubscriptions, numberOfSubscription)

subscriptions.forEach { $0.cancel() }
try await Task.sleep(seconds: 2)
appSyncSubscriptions = await appSyncRealTimeClient.numberOfSubscriptions
XCTAssertEqual(appSyncSubscriptions, 0)

} else {
XCTFail("There should be at least one AppSyncRealTimeClient instance")
}
}

// MARK: Helpers

func createPost(id: String, title: String) async throws -> Post? {
Expand Down Expand Up @@ -499,4 +558,8 @@ class GraphQLModelBasedTests: XCTestCase {
throw error
}
}

func getUnderlyingAPIPlugin() -> AWSAPIPlugin? {
return Amplify.API.plugins["awsAPIPlugin"] as? AWSAPIPlugin
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//


import XCTest
@testable @_spi(AmplifyAPI) import AWSAPIPlugin

class ArrayWithErrorElementExtensionTests: XCTestCase {

/**
Given: errors with generic protocol type
When: cast to the correct underlying concrete type
Then: successfully casted to underlying concrete type
*/
func testCast_toCorrectErrorType_returnCastedErrorType() {
let errors: [Error] = [
Error1(), Error1(), Error1()
]

let error1s = errors.cast(to: Error1.self)
XCTAssertNotNil(error1s)
XCTAssertTrue(!error1s!.isEmpty)
XCTAssertEqual(errors.count, error1s!.count)
}

/**
Given: errors with generic protocol type
When: cast to the wong underlying concrete type
Then: return nil
*/
func testCast_toWrongErrorType_returnNil() {
let errors: [Error] = [
Error1(), Error1(), Error1()
]

let error2s = errors.cast(to: Error2.self)
XCTAssertNil(error2s)
}

/**
Given: errors with generic protocol type
When: some of the elements failed to cast to the underlying concrete type
Then: return nil
*/

func testCast_partiallyToWrongErrorType_returnNil() {
let errors: [Error] = [
Error2(), Error2(), Error1()
]

let error2s = errors.cast(to: Error2.self)
XCTAssertNil(error2s)
}

struct Error1: Error { }

struct Error2: Error { }
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,21 @@ class FetchAuthSessionOperationHelper: DefaultLogger {
}

case .service(let error):
if let authError = (error as? AuthErrorConvertible)?.authError {
let session = AWSAuthCognitoSession(isSignedIn: isSignedIn,
identityIdResult: .failure(authError),
awsCredentialsResult: .failure(authError),
cognitoTokensResult: .failure(authError))
return session
var authError: AuthError
if let convertedAuthError = (error as? AuthErrorConvertible)?.authError {
authError = convertedAuthError
} else {
authError = AuthError.service(
"Unknown service error occurred",
"See the attached error for more details",
error)
}
let session = AWSAuthCognitoSession(
isSignedIn: isSignedIn,
identityIdResult: .failure(authError),
awsCredentialsResult: .failure(authError),
cognitoTokensResult: .failure(authError))
return session
default: break

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,4 +736,61 @@ class AWSAuthFetchSignInSessionOperationTests: BaseAuthorizationTests {
let identityId = try? (session as? AuthCognitoIdentityProvider)?.getIdentityId().get()
XCTAssertNotNil(identityId)
}

/// Test signedIn session with invalid response for aws credentials
///
/// - Given: Given an auth plugin with signedIn state
/// - When:
/// - I invoke fetchAuthSession and service throws NSError
/// - Then:
/// - I should get an a valid session with the following details:
/// - isSignedIn = true
/// - aws credentails = service error
/// - identity id = service error
/// - cognito tokens = service error
///
func testSignInSessionWithNSError() async throws {
let initialState = AuthState.configured(
AuthenticationState.signedIn(.testData),
AuthorizationState.sessionEstablished(
AmplifyCredentials.testDataWithExpiredTokens))

let initAuth: MockIdentityProvider.MockInitiateAuthResponse = { _ in
return InitiateAuthOutput(authenticationResult: .init(accessToken: "accessToken",
expiresIn: 1000,
idToken: "idToken",
refreshToken: "refreshToke"))
}

let awsCredentials: MockIdentity.MockGetCredentialsResponse = { _ in
throw NSError(domain: NSURLErrorDomain, code: 1, userInfo: nil)
}
let plugin = configurePluginWith(
userPool: { MockIdentityProvider(mockInitiateAuthResponse: initAuth) },
identityPool: { MockIdentity(mockGetCredentialsResponse: awsCredentials) },
initialState: initialState)

let session = try await plugin.fetchAuthSession(options: AuthFetchSessionRequest.Options())

XCTAssertTrue(session.isSignedIn)
let credentialsResult = (session as? AuthAWSCredentialsProvider)?.getAWSCredentials()
guard case .failure(let error) = credentialsResult, case .service = error else {
XCTFail("Should return service error")
return
}

let identityIdResult = (session as? AuthCognitoIdentityProvider)?.getIdentityId()
guard case .failure(let identityIdError) = identityIdResult,
case .service = identityIdError else {
XCTFail("Should return service error")
return
}

let tokensResult = (session as? AuthCognitoTokensProvider)?.getCognitoTokens()
guard case .failure(let tokenError) = tokensResult,
case .service = tokenError else {
XCTFail("Should return service error")
return
}
}
}
Loading

0 comments on commit 3079ec7

Please sign in to comment.