From 2c400694acc41fff7d62c2dfa30049afe9ca4a2c Mon Sep 17 00:00:00 2001 From: Jakub Zerko Date: Tue, 10 Feb 2026 10:06:29 +0100 Subject: [PATCH] chore: remove extra MLS verification transactions on epoch changes --- .../com/wire/kalium/cryptography/MLSClient.kt | 17 ++++++++ .../wire/kalium/cryptography/MLSClientImpl.kt | 14 ++++++ gradle/libs.versions.toml | 2 +- .../conversation/MLSConversationRepository.kt | 7 +-- .../kalium/logic/feature/UserSessionScope.kt | 7 ++- ...onversationMLSVerificationStatusUseCase.kt | 10 +---- .../FetchMLSVerificationStatusUseCase.kt | 32 ++++++++------ ...etMembersE2EICertificateStatusesUseCase.kt | 22 +++++----- ...onversationsVerificationStatusesUseCase.kt | 8 +--- .../kalium/logic/feature/user/UserScope.kt | 4 +- .../MLSConversationRepositoryTest.kt | 7 +-- ...mbersE2EICertificateStatusesUseCaseTest.kt | 30 +++++++------ ...rsationMLSVerificationStatusUseCaseTest.kt | 15 +++---- .../FetchMLSVerificationStatusUseCaseTest.kt | 43 ++++++++++++------- ...rsationsVerificationStatusesUseCaseTest.kt | 7 +-- .../provider/E2EIClientProviderArrangement.kt | 15 ++++++- .../FetchMLSVerificationStatusArrangement.kt | 2 +- 17 files changed, 145 insertions(+), 97 deletions(-) diff --git a/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt b/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt index 092253632bd..561235ca893 100644 --- a/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt +++ b/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt @@ -165,6 +165,23 @@ interface MLSClient { */ suspend fun getPublicKey(): Pair + /** + * Conversation E2EI verification status. + * + * Read-only operation that does not require an explicit transaction context. + */ + suspend fun isGroupVerified(groupId: MLSGroupId): E2EIConversationState + + /** + * Get user identities in a conversation. + * + * Read-only operation that does not require an explicit transaction context. + */ + suspend fun getUserIdentities( + groupId: MLSGroupId, + users: List + ): Map> + /** * Runs a block of code inside a CoreCrypto transaction. * diff --git a/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt b/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt index 72640fdadd2..3b89644bb60 100644 --- a/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt +++ b/core/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt @@ -60,6 +60,20 @@ class MLSClientImpl( } } + override suspend fun isGroupVerified(groupId: MLSGroupId): E2EIConversationState { + return coreCrypto.e2eiConversationState(groupId.toCrypto()).toCryptography() + } + + override suspend fun getUserIdentities( + groupId: MLSGroupId, + users: List + ): Map> { + return coreCrypto.getUserIdentities(groupId.toCrypto(), users.map { it.value }) + .mapValues { (_, identities) -> + identities.mapNotNull { identity -> identity.toCryptography() } + } + } + override suspend fun transaction(name: String, block: suspend (context: MlsCoreCryptoContext) -> R): R { return coreCrypto.transaction(name) { context -> block(mlsCoreCryptoContext(context)) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ac0e91f984b..6999fcc3747 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -39,7 +39,7 @@ pbandk = "0.15.0" turbine = "1.1.0" avs = "10.1.41" jna = "5.17.0" -core-crypto = "9.1.3" +core-crypto = "9.2.0" core-crypto-kmp = "9.1.3.5-kmp" desugar-jdk = "2.1.3" kermit = "2.0.3" diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index cd51e1060b5..e5bd8dc2735 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -40,6 +40,7 @@ import com.wire.kalium.common.functional.onSuccess import com.wire.kalium.common.logger.kaliumLogger import com.wire.kalium.cryptography.CryptoQualifiedClientId import com.wire.kalium.cryptography.E2EIClient +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.MlsCoreCryptoContext import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logic.data.client.toDao @@ -222,7 +223,7 @@ internal interface MLSConversationRepository : MLSMemberAdder { ): Either> suspend fun getMembersIdentities( - mlsContext: MlsCoreCryptoContext, + mlsClient: MLSClient, conversationId: ConversationId, userIds: List ): Either>> @@ -744,7 +745,7 @@ internal class MLSConversationDataSource( } override suspend fun getMembersIdentities( - mlsContext: MlsCoreCryptoContext, + mlsClient: MLSClient, conversationId: ConversationId, userIds: List ): Either>> = @@ -754,7 +755,7 @@ internal class MLSConversationDataSource( wrapMLSRequest { val userIdsAndIdentity = mutableMapOf>() - mlsContext.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() }) + mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() }) .forEach { (userIdValue, identities) -> userIds.firstOrNull { it.value == userIdValue }?.also { userIdsAndIdentity[it] = identities diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index a171f5f788b..2223a236b54 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -2508,8 +2508,7 @@ public class UserSessionScope internal constructor( public val fetchConversationMLSVerificationStatus: FetchConversationMLSVerificationStatusUseCase get() = FetchConversationMLSVerificationStatusUseCaseImpl( conversationRepository, - fetchMLSVerificationStatusUseCase, - cryptoTransactionProvider + fetchMLSVerificationStatusUseCase ) public val kaliumFileSystem: KaliumFileSystem by lazy { @@ -2551,6 +2550,7 @@ public class UserSessionScope internal constructor( private val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase by lazy { FetchMLSVerificationStatusUseCaseImpl( + mlsClientProvider, conversationRepository, persistMessage, mlsConversationRepository, @@ -2564,8 +2564,7 @@ public class UserSessionScope internal constructor( ObserveE2EIConversationsVerificationStatusesUseCaseImpl( fetchMLSVerificationStatus = fetchMLSVerificationStatusUseCase, epochChangesObserver = epochChangesObserver, - kaliumLogger = userScopedLogger, - transactionProvider = cryptoTransactionProvider + kaliumLogger = userScopedLogger ) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCase.kt index c0222cc5a96..e2cce52bd84 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCase.kt @@ -21,8 +21,6 @@ import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.common.functional.onSuccess -import com.wire.kalium.common.functional.right -import com.wire.kalium.logic.data.client.CryptoTransactionProvider /** * Trigger the checking and updating MLS Conversations Verification status. @@ -33,18 +31,14 @@ public interface FetchConversationMLSVerificationStatusUseCase { internal class FetchConversationMLSVerificationStatusUseCaseImpl( private val conversationRepository: ConversationRepository, - private val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase, - private val transactionProvider: CryptoTransactionProvider + private val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase ) : FetchConversationMLSVerificationStatusUseCase { override suspend fun invoke(conversationId: ConversationId) { conversationRepository.getConversationById(conversationId).onSuccess { val protocol = it.protocol if (protocol is Conversation.ProtocolInfo.MLSCapable) - transactionProvider.mlsTransaction("FetchConversationMLSVerificationStatus") { mlsContext -> - fetchMLSVerificationStatusUseCase(mlsContext, protocol.groupId) - Unit.right() - } + fetchMLSVerificationStatusUseCase(protocol.groupId) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCase.kt index 1080e755bf5..a1d98652280 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCase.kt @@ -25,16 +25,18 @@ import com.wire.kalium.common.functional.flatMap import com.wire.kalium.common.functional.getOrElse import com.wire.kalium.common.functional.left import com.wire.kalium.common.functional.map +import com.wire.kalium.common.functional.onFailure import com.wire.kalium.common.functional.onSuccess import com.wire.kalium.common.functional.right import com.wire.kalium.cryptography.CredentialType import com.wire.kalium.cryptography.CryptoCertificateStatus -import com.wire.kalium.cryptography.MlsCoreCryptoContext +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.Conversation.VerificationStatus -import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.mls.EpochChangesData import com.wire.kalium.logic.data.conversation.toModel import com.wire.kalium.logic.data.id.ConversationId @@ -55,7 +57,7 @@ internal typealias UserToWireIdentity = Map> */ @Mockable internal interface FetchMLSVerificationStatusUseCase { - suspend operator fun invoke(mlsContext: MlsCoreCryptoContext, groupId: GroupID) + suspend operator fun invoke(groupId: GroupID) } internal data class VerificationStatusData( @@ -66,6 +68,7 @@ internal data class VerificationStatusData( @Suppress("LongParameterList") internal class FetchMLSVerificationStatusUseCaseImpl( + private val mlsClientProvider: MLSClientProvider, private val conversationRepository: ConversationRepository, private val persistMessage: PersistMessageUseCase, private val mlsConversationRepository: MLSConversationRepository, @@ -76,19 +79,24 @@ internal class FetchMLSVerificationStatusUseCaseImpl( private val logger = kaliumLogger.withTextTag("FetchMLSVerificationStatusUseCaseImpl") - override suspend fun invoke(mlsContext: MlsCoreCryptoContext, groupId: GroupID) { - wrapMLSRequest { mlsContext.isGroupVerified(groupId.value) } + override suspend fun invoke(groupId: GroupID) { + mlsClientProvider.getMLSClient() + .onFailure { logger.w("Could not fetch MLS client for verification refresh: $it") } + .onSuccess { mlsClient -> refreshVerificationStatus(mlsClient, groupId) } + } + + private suspend fun refreshVerificationStatus(mlsClient: MLSClient, groupId: GroupID) { + wrapMLSRequest { mlsClient.isGroupVerified(groupId.value) } .map { it.toModel() } .flatMap { ccGroupStatus -> if (ccGroupStatus == VerificationStatus.VERIFIED) { - verifyUsersStatus(mlsContext, groupId) + verifyUsersStatus(mlsClient, groupId) } else { conversationRepository.getConversationByMLSGroupId(groupId).map { VerificationStatusData( conversationId = it.id, currentPersistedStatus = it.mlsVerificationStatus, - newStatus = - ccGroupStatus + newStatus = ccGroupStatus ) } } @@ -97,13 +105,13 @@ internal class FetchMLSVerificationStatusUseCaseImpl( } } - private suspend fun verifyUsersStatus(mlsContext: MlsCoreCryptoContext, groupId: GroupID): Either = + private suspend fun verifyUsersStatus(mlsClient: MLSClient, groupId: GroupID): Either = conversationRepository.getGroupStatusMembersNamesAndHandles(groupId) .flatMap { epochChangesData -> mlsConversationRepository.getMembersIdentities( - mlsContext, - epochChangesData.conversationId, - epochChangesData.members.keys.toList() + mlsClient = mlsClient, + conversationId = epochChangesData.conversationId, + userIds = epochChangesData.members.keys.toList() ) .flatMap { ccIdentities -> updateKnownUsersIfNeeded(epochChangesData, ccIdentities, groupId) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt index 487c6d4a620..438ff7360d4 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt @@ -23,12 +23,13 @@ import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.conversation.mls.NameAndHandle +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.feature.e2ei.CertificateStatus +import com.wire.kalium.common.functional.flatMap import com.wire.kalium.common.functional.getOrElse import com.wire.kalium.common.functional.map -import com.wire.kalium.logic.data.client.CryptoTransactionProvider /** * This use case is used to get the e2ei certificates of all the users in Conversation. @@ -39,19 +40,18 @@ public interface GetMembersE2EICertificateStatusesUseCase { } internal class GetMembersE2EICertificateStatusesUseCaseImpl internal constructor( + private val mlsClientProvider: MLSClientProvider, private val mlsConversationRepository: MLSConversationRepository, - private val conversationRepository: ConversationRepository, - private val transactionProvider: CryptoTransactionProvider + private val conversationRepository: ConversationRepository ) : GetMembersE2EICertificateStatusesUseCase { override suspend operator fun invoke(conversationId: ConversationId, userIds: List): Map = - transactionProvider - .mlsTransaction("E2EIMembersCertificateStatuses") { mlsContext -> - mlsConversationRepository.getMembersIdentities( - mlsContext, - conversationId, - userIds - ) - } + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + mlsConversationRepository.getMembersIdentities( + mlsClient, + conversationId, + userIds + ) + } .map { identities -> val usersNameAndHandle = conversationRepository.selectMembersNameAndHandle(conversationId).getOrElse(mapOf()) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCase.kt index 8e3d5638bf5..0ad3cfd6f37 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCase.kt @@ -17,9 +17,7 @@ */ package com.wire.kalium.logic.feature.e2ei.usecase -import com.wire.kalium.common.functional.right import com.wire.kalium.logger.KaliumLogger -import com.wire.kalium.logic.data.client.CryptoTransactionProvider import com.wire.kalium.logic.data.conversation.EpochChangesObserver /** @@ -33,7 +31,6 @@ internal interface ObserveE2EIConversationsVerificationStatusesUseCase { internal class ObserveE2EIConversationsVerificationStatusesUseCaseImpl( private val fetchMLSVerificationStatus: FetchMLSVerificationStatusUseCase, private val epochChangesObserver: EpochChangesObserver, - private val transactionProvider: CryptoTransactionProvider, kaliumLogger: KaliumLogger ) : ObserveE2EIConversationsVerificationStatusesUseCase { @@ -44,10 +41,7 @@ internal class ObserveE2EIConversationsVerificationStatusesUseCaseImpl( epochChangesObserver.observe() .collect { groupWithEpoch -> logger.d("Epoch for group ${groupWithEpoch.groupId.toLogString()} changed ${groupWithEpoch.epoch}") - transactionProvider.mlsTransaction("ObserveE2EIConversationsVerificationStatuses") { mlsContext -> - fetchMLSVerificationStatus(mlsContext, groupWithEpoch.groupId) - Unit.right() - } + fetchMLSVerificationStatus(groupWithEpoch.groupId) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt index 4be3929cd99..3ace93886c2 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt @@ -190,9 +190,9 @@ public class UserScope internal constructor( ) public val getMembersE2EICertificateStatuses: GetMembersE2EICertificateStatusesUseCase get() = GetMembersE2EICertificateStatusesUseCaseImpl( + mlsClientProvider = transactionProvider.mlsClientProvider, mlsConversationRepository = mlsConversationRepository, - conversationRepository = conversationRepository, - transactionProvider = transactionProvider + conversationRepository = conversationRepository ) public val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository) public val setUserHandle: SetUserHandleUseCase get() = SetUserHandleUseCase(accountRepository, validateUserHandleUseCase, syncManager) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index 341c3a24b45..f7a7d317844 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -70,6 +70,7 @@ import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.test_util.TestKaliumDispatcher import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementImpl +import com.wire.kalium.logic.util.arrangement.provider.DummyMLSClient import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.authenticated.client.DeviceTypeDTO @@ -1306,7 +1307,7 @@ class MLSConversationRepositoryTest { member2 to listOf(WIRE_IDENTITY.copy(clientId = CRYPTO_CLIENT_ID.copy("member_2_client_id"))) ) ), - mlsConversationRepository.getMembersIdentities(arrangement.mlsContext, TestConversation.ID, listOf(member1, member2, member3)) + mlsConversationRepository.getMembersIdentities(arrangement.mlsClient, TestConversation.ID, listOf(member1, member2, member3)) ) coVerify { @@ -1396,7 +1397,7 @@ class MLSConversationRepositoryTest { .withGetUserIdentitiesReturn(mapOf(groupId to listOf(wireIdentity))) .arrange() // when - val result = mlsConversationRepository.getMembersIdentities(arrangement.mlsContext, TestConversation.ID, listOf(TestUser.USER_ID)) + val result = mlsConversationRepository.getMembersIdentities(arrangement.mlsClient, TestConversation.ID, listOf(TestUser.USER_ID)) // then result.shouldSucceed() { it.values.forEach { @@ -1480,7 +1481,7 @@ class MLSConversationRepositoryTest { val checkRevocationList = mock(RevocationListChecker::class) val certificateRevocationListRepository = mock(CertificateRevocationListRepository::class) val epochChangesObserver = mock(EpochChangesObserver::class) - + val mlsClient = DummyMLSClient(mlsContext) val epochsFlow = MutableSharedFlow() val proposalTimersFlow = MutableSharedFlow() diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt index ef7c3368bec..4898060ae84 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt @@ -20,8 +20,10 @@ package com.wire.kalium.logic.feature.e2ei import com.wire.kalium.cryptography.CredentialType import com.wire.kalium.cryptography.CryptoCertificateStatus import com.wire.kalium.cryptography.CryptoQualifiedClientId +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.common.error.MLSFailure +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.mls.NameAndHandle import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.toCrypto @@ -30,11 +32,11 @@ import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatu import com.wire.kalium.common.functional.Either import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl -import kotlinx.coroutines.runBlocking +import io.mockative.any +import io.mockative.coEvery +import io.mockative.mock import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant import kotlin.test.Test @@ -116,24 +118,28 @@ class GetMembersE2EICertificateStatusesUseCaseTest { private class Arrangement(private val block: suspend Arrangement.() -> Unit) : MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(), - CryptoTransactionProviderArrangement by CryptoTransactionProviderArrangementImpl(), ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl() { - fun arrange() = run { - runBlocking { - withMLSTransactionReturning(Either.Right(Unit)) - block() - } + val mlsClientProvider = mock(MLSClientProvider::class) + val mlsClient = mock(MLSClient::class) + + suspend fun withMLSClientSuccess() { + coEvery { mlsClientProvider.getMLSClient(any()) }.returns(Either.Right(mlsClient)) + } + + suspend fun arrange() = run { + withMLSClientSuccess() + block() this@Arrangement to GetMembersE2EICertificateStatusesUseCaseImpl( + mlsClientProvider = mlsClientProvider, mlsConversationRepository = mlsConversationRepository, - conversationRepository = conversationRepository, - transactionProvider = cryptoTransactionProvider + conversationRepository = conversationRepository ) } } private companion object { - fun arrange(configuration: suspend Arrangement.() -> Unit) = Arrangement(configuration).arrange() + suspend fun arrange(configuration: suspend Arrangement.() -> Unit) = Arrangement(configuration).arrange() private val USER_ID = UserId("value", "domain") private val CRYPTO_QUALIFIED_CLIENT_ID = diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCaseTest.kt index 2243844569d..38d2a684c59 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchConversationMLSVerificationStatusUseCaseTest.kt @@ -18,10 +18,8 @@ package com.wire.kalium.logic.feature.e2ei.usecase import com.wire.kalium.common.error.StorageFailure -import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.common.functional.Either -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementImpl +import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.usecase.FetchMLSVerificationStatusArrangement @@ -46,7 +44,7 @@ class FetchConversationMLSVerificationStatusUseCaseTest { useCase(TestConversation.ID) advanceUntilIdle() - coVerify { arrangement.fetchMLSVerificationStatusUseCase(any(), any()) } + coVerify { arrangement.fetchMLSVerificationStatusUseCase(any()) } .wasNotInvoked() } @@ -59,7 +57,7 @@ class FetchConversationMLSVerificationStatusUseCaseTest { useCase(TestConversation.ID) advanceUntilIdle() - coVerify { arrangement.fetchMLSVerificationStatusUseCase(any(), any()) } + coVerify { arrangement.fetchMLSVerificationStatusUseCase(any()) } .wasNotInvoked() } @@ -73,7 +71,7 @@ class FetchConversationMLSVerificationStatusUseCaseTest { useCase(TestConversation.ID) advanceUntilIdle() - coVerify { arrangement.fetchMLSVerificationStatusUseCase(any(), eq(protocolInfo.groupId)) } + coVerify { arrangement.fetchMLSVerificationStatusUseCase(eq(protocolInfo.groupId)) } .wasInvoked() } @@ -82,17 +80,14 @@ class FetchConversationMLSVerificationStatusUseCaseTest { private class Arrangement( private val block: suspend Arrangement.() -> Unit ) : FetchMLSVerificationStatusArrangement by FetchMLSVerificationStatusArrangementImpl(), - CryptoTransactionProviderArrangement by CryptoTransactionProviderArrangementImpl(), ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl() { suspend fun arrange() = let { block() - withMLSTransactionReturning(Either.Right(Unit)) mockFetchMLSVerificationStatus() this to FetchConversationMLSVerificationStatusUseCaseImpl( conversationRepository = conversationRepository, - fetchMLSVerificationStatusUseCase = fetchMLSVerificationStatusUseCase, - transactionProvider = cryptoTransactionProvider + fetchMLSVerificationStatusUseCase = fetchMLSVerificationStatusUseCase ) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCaseTest.kt index a2e9b0f8791..f93c6dee87e 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/FetchMLSVerificationStatusUseCaseTest.kt @@ -24,7 +24,9 @@ import com.wire.kalium.cryptography.CredentialType import com.wire.kalium.cryptography.CryptoCertificateStatus import com.wire.kalium.cryptography.CryptoQualifiedClientId import com.wire.kalium.cryptography.E2EIConversationState +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.WireIdentity +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.conversation.mls.EpochChangesData @@ -36,8 +38,6 @@ import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.feature.e2ei.usecase.FetchMLSVerificationStatusUseCaseTest.Arrangement.Companion.getMockedIdentity import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestUser -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangement @@ -63,11 +63,12 @@ class FetchMLSVerificationStatusUseCaseTest { fun givenNotVerifiedConversation_whenNotVerifiedStatusComes_thenNothingChanged() = runTest { val conversationDetails = TestConversation.MLS_CONVERSATION val (arrangement, handler) = arrange { + withGetMLSClientSuccess() withIsGroupVerified(E2EIConversationState.NOT_VERIFIED) withConversationByMLSGroupId(Either.Right(conversationDetails)) } - handler(arrangement.mlsContext, TestConversation.GROUP_ID) + handler(TestConversation.GROUP_ID) advanceUntilIdle() coVerify { @@ -89,11 +90,12 @@ class FetchMLSVerificationStatusUseCaseTest { mlsVerificationStatus = Conversation.VerificationStatus.VERIFIED ) val (arrangement, handler) = arrange { + withGetMLSClientSuccess() withIsGroupVerified(E2EIConversationState.NOT_VERIFIED) withConversationByMLSGroupId(Either.Right(conversationDetails)) } - handler(arrangement.mlsContext, TestConversation.GROUP_ID) + handler(TestConversation.GROUP_ID) advanceUntilIdle() coVerify { @@ -119,11 +121,12 @@ class FetchMLSVerificationStatusUseCaseTest { ) val (arrangement, handler) = arrange { + withGetMLSClientSuccess() withIsGroupVerified(E2EIConversationState.NOT_VERIFIED) withConversationByMLSGroupId(Either.Right(conversationDetails)) } - handler(arrangement.mlsContext, TestConversation.GROUP_ID) + handler(TestConversation.GROUP_ID) advanceUntilIdle() coVerify { @@ -160,12 +163,13 @@ class FetchMLSVerificationStatusUseCaseTest { ) ) val (arrangement, handler) = arrange { + withGetMLSClientSuccess() withIsGroupVerified(E2EIConversationState.VERIFIED) withSelectGroupStatusMembersNamesAndHandles(Either.Right(epochChangedData)) - withGetMembersIdentities(Either.Right(ccMembersIdentity)) + withGetMembersIdentities(ccMembersIdentity) } - handler(arrangement.mlsContext, TestConversation.GROUP_ID) + handler(TestConversation.GROUP_ID) advanceUntilIdle() coVerify { @@ -205,12 +209,13 @@ class FetchMLSVerificationStatusUseCaseTest { ) ) val (arrangement, handler) = arrange { + withGetMLSClientSuccess() withIsGroupVerified(E2EIConversationState.VERIFIED) withSelectGroupStatusMembersNamesAndHandles(Either.Right(epochChangedData)) - withGetMembersIdentities(Either.Right(ccMembersIdentity)) + withGetMembersIdentities(ccMembersIdentity) } - handler(arrangement.mlsContext, TestConversation.GROUP_ID) + handler(TestConversation.GROUP_ID) advanceUntilIdle() coVerify { @@ -235,21 +240,26 @@ class FetchMLSVerificationStatusUseCaseTest { private val block: suspend Arrangement.() -> Unit ) : ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), PersistMessageUseCaseArrangement by PersistMessageUseCaseArrangementImpl(), - CryptoTransactionProviderArrangement by CryptoTransactionProviderArrangementImpl(), UserRepositoryArrangement by UserRepositoryArrangementImpl() { + val mlsClientProvider = mock(MLSClientProvider::class) + val mlsClient = mock(MLSClient::class) val mlsConversationRepository = mock(MLSConversationRepository::class) + suspend fun withGetMLSClientSuccess() { + coEvery { mlsClientProvider.getMLSClient(any()) }.returns(Either.Right(mlsClient)) + } + suspend fun withIsGroupVerified(result: E2EIConversationState) { coEvery { - mlsContext.isGroupVerified(any()) + mlsClient.isGroupVerified(any()) }.returns(result) } - suspend fun withGetMembersIdentities(result: Either>>) { + suspend fun withGetMembersIdentities(result: Map>) { coEvery { - mlsConversationRepository.getMembersIdentities(any(), any(), any()) - }.returns(result) + mlsConversationRepository.getMembersIdentities(eq(mlsClient), any(), any()) + }.returns(Either.Right(result)) } suspend inline fun arrange() = let { @@ -258,12 +268,13 @@ class FetchMLSVerificationStatusUseCaseTest { withSetDegradedConversationNotifiedFlag(Either.Right(Unit)) block() this to FetchMLSVerificationStatusUseCaseImpl( + mlsClientProvider = mlsClientProvider, + mlsConversationRepository = mlsConversationRepository, conversationRepository = conversationRepository, persistMessage = persistMessageUseCase, selfUserId = TestUser.USER_ID, kaliumLogger = kaliumLogger, - userRepository = userRepository, - mlsConversationRepository = mlsConversationRepository + userRepository = userRepository ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCaseTest.kt index 76179d17058..739c5078670 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/ObserveE2EIConversationsVerificationStatusesUseCaseTest.kt @@ -20,8 +20,6 @@ package com.wire.kalium.logic.feature.e2ei.usecase import com.wire.kalium.common.functional.Either import com.wire.kalium.common.logger.kaliumLogger import com.wire.kalium.logic.framework.TestConversation -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement -import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.MLSConversationRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.MLSConversationRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.usecase.FetchMLSVerificationStatusArrangement @@ -47,7 +45,7 @@ class ObserveE2EIConversationsVerificationStatusesUseCaseTest { handler() advanceUntilIdle() - coVerify { arrangement.fetchMLSVerificationStatusUseCase(any(), eq(TestConversation.GROUP_ID)) } + coVerify { arrangement.fetchMLSVerificationStatusUseCase(eq(TestConversation.GROUP_ID)) } .wasInvoked() } @@ -56,17 +54,14 @@ class ObserveE2EIConversationsVerificationStatusesUseCaseTest { private class Arrangement( private val block: Arrangement.() -> Unit ) : FetchMLSVerificationStatusArrangement by FetchMLSVerificationStatusArrangementImpl(), - CryptoTransactionProviderArrangement by CryptoTransactionProviderArrangementImpl(), MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl() { suspend fun arrange() = let { block() mockFetchMLSVerificationStatus() - withMLSTransactionReturning(Either.Right(Unit)) this to ObserveE2EIConversationsVerificationStatusesUseCaseImpl( epochChangesObserver = epochChangesObserver, fetchMLSVerificationStatus = fetchMLSVerificationStatusUseCase, - transactionProvider = cryptoTransactionProvider, kaliumLogger = kaliumLogger, ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt index 5c9a2fc61c2..f997984720b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt @@ -28,8 +28,11 @@ import com.wire.kalium.logic.data.user.SelfUser import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.common.functional.Either import com.wire.kalium.common.functional.right +import com.wire.kalium.cryptography.CryptoQualifiedID +import com.wire.kalium.cryptography.E2EIConversationState import com.wire.kalium.cryptography.MLSCiphersuite import com.wire.kalium.cryptography.MlsCoreCryptoContext +import com.wire.kalium.cryptography.WireIdentity import dev.mokkery.answering.returns import dev.mokkery.everySuspend import dev.mokkery.matcher.any as mokkeryAny @@ -196,9 +199,19 @@ class DummyMLSClient( TODO("Not yet implemented") } + override suspend fun isGroupVerified(groupId: String): E2EIConversationState { + return context.isGroupVerified(groupId) + } + + override suspend fun getUserIdentities( + groupId: String, + users: List + ): Map> { + return context.getUserIdentities(groupId, users) + } + override suspend fun transaction(name: String, block: suspend (context: MlsCoreCryptoContext) -> R): R { return block(context) } } - diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/FetchMLSVerificationStatusArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/FetchMLSVerificationStatusArrangement.kt index 36e79072603..0f98abbc9c3 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/FetchMLSVerificationStatusArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/FetchMLSVerificationStatusArrangement.kt @@ -33,6 +33,6 @@ internal class FetchMLSVerificationStatusArrangementImpl : FetchMLSVerificationS override val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase = mock(FetchMLSVerificationStatusUseCase::class) override suspend fun mockFetchMLSVerificationStatus() { - coEvery { fetchMLSVerificationStatusUseCase(any(), any()) }.returns(Unit) + coEvery { fetchMLSVerificationStatusUseCase(any()) }.returns(Unit) } }