From d4b957d5bf268096117b2aece892ffb4fd88b0f0 Mon Sep 17 00:00:00 2001 From: Abhash Kumar Singh Date: Fri, 3 Nov 2023 13:34:49 -0700 Subject: [PATCH] fix(datastore): multi auth rule for read subscription (#3316) * fix(datastore): multi auth rule for read subscription * Address review comments --- .../Model/Internal/Schema/AuthRule.swift | 4 +++ .../Auth/AWSAuthModeStrategy.swift | 33 ++++++++++++++++--- .../Auth/AuthModeStrategyTests.swift | 31 ++++++++++++++--- ...omingAsyncSubscriptionEventPublisher.swift | 6 ++-- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/Amplify/Categories/DataStore/Model/Internal/Schema/AuthRule.swift b/Amplify/Categories/DataStore/Model/Internal/Schema/AuthRule.swift index 5c5027aff2..8a5c6b4aeb 100644 --- a/Amplify/Categories/DataStore/Model/Internal/Schema/AuthRule.swift +++ b/Amplify/Categories/DataStore/Model/Internal/Schema/AuthRule.swift @@ -68,3 +68,7 @@ public struct AuthRule { self.operations = operations } } + +extension AuthRule: Hashable { + +} diff --git a/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthModeStrategy.swift b/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthModeStrategy.swift index a1355ced5a..1f3edd8786 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthModeStrategy.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/Auth/AWSAuthModeStrategy.swift @@ -42,6 +42,8 @@ public protocol AuthModeStrategy: AnyObject { init() func authTypesFor(schema: ModelSchema, operation: ModelOperation) async -> AWSAuthorizationTypeIterator + + func authTypesFor(schema: ModelSchema, operations: [ModelOperation]) async -> AWSAuthorizationTypeIterator } /// AuthorizationType iterator with an extra `count` property used @@ -93,6 +95,11 @@ public class AWSDefaultAuthModeStrategy: AuthModeStrategy { operation: ModelOperation) -> AWSAuthorizationTypeIterator { return AWSAuthorizationTypeIterator(withValues: []) } + + public func authTypesFor(schema: ModelSchema, + operations: [ModelOperation]) -> AWSAuthorizationTypeIterator { + return AWSAuthorizationTypeIterator(withValues: []) + } } // MARK: - AWSMultiAuthModeStrategy @@ -188,19 +195,35 @@ public class AWSMultiAuthModeStrategy: AuthModeStrategy { /// - Returns: an iterator for the applicable auth rules public func authTypesFor(schema: ModelSchema, operation: ModelOperation) async -> AWSAuthorizationTypeIterator { - var applicableAuthRules = schema.authRules - .filter(modelOperation: operation) + return await authTypesFor(schema: schema, operations: [operation]) + } + + /// Returns the union of authorization types for the provided schema for the given list of operations + /// - Parameters: + /// - schema: model schema + /// - operations: model operations + /// - Returns: an iterator for the applicable auth rules + public func authTypesFor(schema: ModelSchema, + operations: [ModelOperation]) async -> AWSAuthorizationTypeIterator { + var sortedRules = operations + .flatMap { schema.authRules.filter(modelOperation: $0) } + .reduce(into: [AuthRule](), { array, rule in + if !array.contains(rule) { + array.append(rule) + } + }) .sorted(by: AWSMultiAuthModeStrategy.comparator) - + // if there isn't a user signed in, returns only public or custom rules if let authDelegate = authDelegate, await !authDelegate.isUserLoggedIn() { - applicableAuthRules = applicableAuthRules.filter { rule in + sortedRules = sortedRules.filter { rule in return rule.allow == .public || rule.allow == .custom } } - let applicableAuthTypes = applicableAuthRules.map { + let applicableAuthTypes = sortedRules.map { AWSMultiAuthModeStrategy.authTypeFor(authRule: $0) } return AWSAuthorizationTypeIterator(withValues: applicableAuthTypes) } + } diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Auth/AuthModeStrategyTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Auth/AuthModeStrategyTests.swift index 12ecb27a6d..304aac6add 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Auth/AuthModeStrategyTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Auth/AuthModeStrategyTests.swift @@ -88,9 +88,9 @@ class AuthModeStrategyTests: XCTestCase { let authMode = AWSMultiAuthModeStrategy() let delegate = UnauthenticatedUserDelegate() authMode.authDelegate = delegate - + var authTypesIterator = await authMode.authTypesFor(schema: ModelWithOwnerAndPublicAuth.schema, - operation: .create) + operation: .create) XCTAssertEqual(authTypesIterator.count, 1) XCTAssertEqual(authTypesIterator.next(), .apiKey) } @@ -101,7 +101,7 @@ class AuthModeStrategyTests: XCTestCase { func testMultiAuthPriorityWithCustomStrategy() async { let authMode = AWSMultiAuthModeStrategy() var authTypesIterator = await authMode.authTypesFor(schema: ModelWithCustomStrategy.schema, - operation: .create) + operation: .create) XCTAssertEqual(authTypesIterator.count, 3) XCTAssertEqual(authTypesIterator.next(), .function) XCTAssertEqual(authTypesIterator.next(), .amazonCognitoUserPools) @@ -117,12 +117,35 @@ class AuthModeStrategyTests: XCTestCase { authMode.authDelegate = delegate var authTypesIterator = await authMode.authTypesFor(schema: ModelWithCustomStrategy.schema, - operation: .create) + operation: .create) XCTAssertEqual(authTypesIterator.count, 2) XCTAssertEqual(authTypesIterator.next(), .function) XCTAssertEqual(authTypesIterator.next(), .awsIAM) } + // Given: multi-auth strategy and a model schema without auth provider + // When: auth types are requested with multiple operation + // Then: default values based on the auth strategy should be returned + func testMultiAuthShouldReturnDefaultAuthTypesForMultipleOperation() async { + let authMode = AWSMultiAuthModeStrategy() + var authTypesIterator = await authMode.authTypesFor(schema: ModelNoProvider.schema, operations: [.read, .create]) + XCTAssertEqual(authTypesIterator.count, 2) + XCTAssertEqual(authTypesIterator.next(), .amazonCognitoUserPools) + XCTAssertEqual(authTypesIterator.next(), .apiKey) + } + + // Given: multi-auth strategy and a model schema with auth provider + // When: auth types are requested with multiple operation + // Then: auth rule for public access should be returned + func testMultiAuthReturnDefaultAuthTypesForMultipleOperationWithProvider() async { + let authMode = AWSMultiAuthModeStrategy() + let delegate = UnauthenticatedUserDelegate() + authMode.authDelegate = delegate + var authTypesIterator = await authMode.authTypesFor(schema: ModelNoProvider.schema, operations: [.read, .create]) + XCTAssertEqual(authTypesIterator.count, 1) + XCTAssertEqual(authTypesIterator.next(), .apiKey) + } + } // MARK: - Test models diff --git a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift index 301e5534b8..c45ba79650 100644 --- a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift +++ b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisher.swift @@ -73,7 +73,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { // onCreate operation let onCreateValueListener = onCreateValueListenerHandler(event:) let onCreateAuthTypeProvider = await authModeStrategy.authTypesFor(schema: modelSchema, - operation: .create) + operations: [.create, .read]) self.onCreateValueListener = onCreateValueListener self.onCreateOperation = RetryableGraphQLSubscriptionOperation( requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor( @@ -94,7 +94,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { // onUpdate operation let onUpdateValueListener = onUpdateValueListenerHandler(event:) let onUpdateAuthTypeProvider = await authModeStrategy.authTypesFor(schema: modelSchema, - operation: .update) + operations: [.update, .read]) self.onUpdateValueListener = onUpdateValueListener self.onUpdateOperation = RetryableGraphQLSubscriptionOperation( requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor( @@ -115,7 +115,7 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable { // onDelete operation let onDeleteValueListener = onDeleteValueListenerHandler(event:) let onDeleteAuthTypeProvider = await authModeStrategy.authTypesFor(schema: modelSchema, - operation: .delete) + operations: [.delete, .read]) self.onDeleteValueListener = onDeleteValueListener self.onDeleteOperation = RetryableGraphQLSubscriptionOperation( requestFactory: IncomingAsyncSubscriptionEventPublisher.apiRequestFactoryFor(