Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,23 @@ interface MLSClient {
*/
suspend fun getPublicKey(): Pair<ByteArray, MLSCiphersuite>

/**
* Conversation E2EI verification status.
*
* Read-only operation that does not require an explicit transaction context.
*/
suspend fun isGroupVerified(groupId: MLSGroupId): E2EIConversationState

/**
* Get user identities in a conversation.
*
* Read-only operation that does not require an explicit transaction context.
*/
suspend fun getUserIdentities(
groupId: MLSGroupId,
users: List<CryptoQualifiedID>
): Map<String, List<WireIdentity>>

/**
* Runs a block of code inside a CoreCrypto transaction.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ class MLSClientImpl(
}
}

override suspend fun isGroupVerified(groupId: MLSGroupId): E2EIConversationState {
return coreCrypto.e2eiConversationState(groupId.toCrypto()).toCryptography()
}

override suspend fun getUserIdentities(
groupId: MLSGroupId,
users: List<CryptoQualifiedID>
): Map<String, List<WireIdentity>> {
return coreCrypto.getUserIdentities(groupId.toCrypto(), users.map { it.value })
.mapValues { (_, identities) ->
identities.mapNotNull { identity -> identity.toCryptography() }
}
}

override suspend fun <R> transaction(name: String, block: suspend (context: MlsCoreCryptoContext) -> R): R {
return coreCrypto.transaction(name) { context ->
block(mlsCoreCryptoContext(context))
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pbandk = "0.15.0"
turbine = "1.1.0"
avs = "10.1.41"
jna = "5.17.0"
core-crypto = "9.1.3"
core-crypto = "9.2.0"
core-crypto-kmp = "9.1.3.5-kmp"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
core-crypto-kmp = "9.1.3.5-kmp"
core-crypto-kmp = "9.2.0.1-kmp"

desugar-jdk = "2.1.3"
kermit = "2.0.3"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import com.wire.kalium.common.functional.onSuccess
import com.wire.kalium.common.logger.kaliumLogger
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.MlsCoreCryptoContext
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.data.client.toDao
Expand Down Expand Up @@ -222,7 +223,7 @@ internal interface MLSConversationRepository : MLSMemberAdder {
): Either<CoreFailure, List<WireIdentity>>

suspend fun getMembersIdentities(
mlsContext: MlsCoreCryptoContext,
mlsClient: MLSClient,
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>>
Expand Down Expand Up @@ -744,7 +745,7 @@ internal class MLSConversationDataSource(
}

override suspend fun getMembersIdentities(
mlsContext: MlsCoreCryptoContext,
mlsClient: MLSClient,
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>> =
Expand All @@ -754,7 +755,7 @@ internal class MLSConversationDataSource(
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()

mlsContext.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2508,8 +2508,7 @@ public class UserSessionScope internal constructor(
public val fetchConversationMLSVerificationStatus: FetchConversationMLSVerificationStatusUseCase
get() = FetchConversationMLSVerificationStatusUseCaseImpl(
conversationRepository,
fetchMLSVerificationStatusUseCase,
cryptoTransactionProvider
fetchMLSVerificationStatusUseCase
)

public val kaliumFileSystem: KaliumFileSystem by lazy {
Expand Down Expand Up @@ -2551,6 +2550,7 @@ public class UserSessionScope internal constructor(

private val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase by lazy {
FetchMLSVerificationStatusUseCaseImpl(
mlsClientProvider,
conversationRepository,
persistMessage,
mlsConversationRepository,
Expand All @@ -2564,8 +2564,7 @@ public class UserSessionScope internal constructor(
ObserveE2EIConversationsVerificationStatusesUseCaseImpl(
fetchMLSVerificationStatus = fetchMLSVerificationStatusUseCase,
epochChangesObserver = epochChangesObserver,
kaliumLogger = userScopedLogger,
transactionProvider = cryptoTransactionProvider
kaliumLogger = userScopedLogger
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.common.functional.onSuccess
import com.wire.kalium.common.functional.right
import com.wire.kalium.logic.data.client.CryptoTransactionProvider

/**
* Trigger the checking and updating MLS Conversations Verification status.
Expand All @@ -33,18 +31,14 @@ public interface FetchConversationMLSVerificationStatusUseCase {

internal class FetchConversationMLSVerificationStatusUseCaseImpl(
private val conversationRepository: ConversationRepository,
private val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase,
private val transactionProvider: CryptoTransactionProvider
private val fetchMLSVerificationStatusUseCase: FetchMLSVerificationStatusUseCase
) : FetchConversationMLSVerificationStatusUseCase {

override suspend fun invoke(conversationId: ConversationId) {
conversationRepository.getConversationById(conversationId).onSuccess {
val protocol = it.protocol
if (protocol is Conversation.ProtocolInfo.MLSCapable)
transactionProvider.mlsTransaction("FetchConversationMLSVerificationStatus") { mlsContext ->
fetchMLSVerificationStatusUseCase(mlsContext, protocol.groupId)
Unit.right()
}
fetchMLSVerificationStatusUseCase(protocol.groupId)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@ import com.wire.kalium.common.functional.flatMap
import com.wire.kalium.common.functional.getOrElse
import com.wire.kalium.common.functional.left
import com.wire.kalium.common.functional.map
import com.wire.kalium.common.functional.onFailure
import com.wire.kalium.common.functional.onSuccess
import com.wire.kalium.common.functional.right
import com.wire.kalium.cryptography.CredentialType
import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.MlsCoreCryptoContext
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.Conversation.VerificationStatus
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.mls.EpochChangesData
import com.wire.kalium.logic.data.conversation.toModel
import com.wire.kalium.logic.data.id.ConversationId
Expand All @@ -55,7 +57,7 @@ internal typealias UserToWireIdentity = Map<UserId, List<WireIdentity>>
*/
@Mockable
internal interface FetchMLSVerificationStatusUseCase {
suspend operator fun invoke(mlsContext: MlsCoreCryptoContext, groupId: GroupID)
suspend operator fun invoke(groupId: GroupID)
}

internal data class VerificationStatusData(
Expand All @@ -66,6 +68,7 @@ internal data class VerificationStatusData(

@Suppress("LongParameterList")
internal class FetchMLSVerificationStatusUseCaseImpl(
private val mlsClientProvider: MLSClientProvider,
private val conversationRepository: ConversationRepository,
private val persistMessage: PersistMessageUseCase,
private val mlsConversationRepository: MLSConversationRepository,
Expand All @@ -76,19 +79,24 @@ internal class FetchMLSVerificationStatusUseCaseImpl(

private val logger = kaliumLogger.withTextTag("FetchMLSVerificationStatusUseCaseImpl")

override suspend fun invoke(mlsContext: MlsCoreCryptoContext, groupId: GroupID) {
wrapMLSRequest { mlsContext.isGroupVerified(groupId.value) }
override suspend fun invoke(groupId: GroupID) {
mlsClientProvider.getMLSClient()
.onFailure { logger.w("Could not fetch MLS client for verification refresh: $it") }
.onSuccess { mlsClient -> refreshVerificationStatus(mlsClient, groupId) }
}

private suspend fun refreshVerificationStatus(mlsClient: MLSClient, groupId: GroupID) {
wrapMLSRequest { mlsClient.isGroupVerified(groupId.value) }
.map { it.toModel() }
.flatMap { ccGroupStatus ->
if (ccGroupStatus == VerificationStatus.VERIFIED) {
verifyUsersStatus(mlsContext, groupId)
verifyUsersStatus(mlsClient, groupId)
} else {
conversationRepository.getConversationByMLSGroupId(groupId).map {
VerificationStatusData(
conversationId = it.id,
currentPersistedStatus = it.mlsVerificationStatus,
newStatus =
ccGroupStatus
newStatus = ccGroupStatus
)
}
}
Expand All @@ -97,13 +105,13 @@ internal class FetchMLSVerificationStatusUseCaseImpl(
}
}

private suspend fun verifyUsersStatus(mlsContext: MlsCoreCryptoContext, groupId: GroupID): Either<CoreFailure, VerificationStatusData> =
private suspend fun verifyUsersStatus(mlsClient: MLSClient, groupId: GroupID): Either<CoreFailure, VerificationStatusData> =
conversationRepository.getGroupStatusMembersNamesAndHandles(groupId)
.flatMap { epochChangesData ->
mlsConversationRepository.getMembersIdentities(
mlsContext,
epochChangesData.conversationId,
epochChangesData.members.keys.toList()
mlsClient = mlsClient,
conversationId = epochChangesData.conversationId,
userIds = epochChangesData.members.keys.toList()
)
.flatMap { ccIdentities ->
updateKnownUsersIfNeeded(epochChangesData, ccIdentities, groupId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.conversation.mls.NameAndHandle
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.common.functional.flatMap
import com.wire.kalium.common.functional.getOrElse
import com.wire.kalium.common.functional.map
import com.wire.kalium.logic.data.client.CryptoTransactionProvider

/**
* This use case is used to get the e2ei certificates of all the users in Conversation.
Expand All @@ -39,19 +40,18 @@ public interface GetMembersE2EICertificateStatusesUseCase {
}

internal class GetMembersE2EICertificateStatusesUseCaseImpl internal constructor(
private val mlsClientProvider: MLSClientProvider,
private val mlsConversationRepository: MLSConversationRepository,
private val conversationRepository: ConversationRepository,
private val transactionProvider: CryptoTransactionProvider
private val conversationRepository: ConversationRepository
) : GetMembersE2EICertificateStatusesUseCase {
override suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, Boolean> =
transactionProvider
.mlsTransaction("E2EIMembersCertificateStatuses") { mlsContext ->
mlsConversationRepository.getMembersIdentities(
mlsContext,
conversationId,
userIds
)
}
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
mlsConversationRepository.getMembersIdentities(
mlsClient,
conversationId,
userIds
)
}
.map { identities ->
val usersNameAndHandle = conversationRepository.selectMembersNameAndHandle(conversationId).getOrElse(mapOf())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.common.functional.right
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.data.client.CryptoTransactionProvider
import com.wire.kalium.logic.data.conversation.EpochChangesObserver

/**
Expand All @@ -33,7 +31,6 @@ internal interface ObserveE2EIConversationsVerificationStatusesUseCase {
internal class ObserveE2EIConversationsVerificationStatusesUseCaseImpl(
private val fetchMLSVerificationStatus: FetchMLSVerificationStatusUseCase,
private val epochChangesObserver: EpochChangesObserver,
private val transactionProvider: CryptoTransactionProvider,
kaliumLogger: KaliumLogger
) : ObserveE2EIConversationsVerificationStatusesUseCase {

Expand All @@ -44,10 +41,7 @@ internal class ObserveE2EIConversationsVerificationStatusesUseCaseImpl(
epochChangesObserver.observe()
.collect { groupWithEpoch ->
logger.d("Epoch for group ${groupWithEpoch.groupId.toLogString()} changed ${groupWithEpoch.epoch}")
transactionProvider.mlsTransaction("ObserveE2EIConversationsVerificationStatuses") { mlsContext ->
fetchMLSVerificationStatus(mlsContext, groupWithEpoch.groupId)
Unit.right()
}
fetchMLSVerificationStatus(groupWithEpoch.groupId)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ public class UserScope internal constructor(
)
public val getMembersE2EICertificateStatuses: GetMembersE2EICertificateStatusesUseCase
get() = GetMembersE2EICertificateStatusesUseCaseImpl(
mlsClientProvider = transactionProvider.mlsClientProvider,
mlsConversationRepository = mlsConversationRepository,
conversationRepository = conversationRepository,
transactionProvider = transactionProvider
conversationRepository = conversationRepository
)
public val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository)
public val setUserHandle: SetUserHandleUseCase get() = SetUserHandleUseCase(accountRepository, validateUserHandleUseCase, syncManager)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement
import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementImpl
import com.wire.kalium.logic.util.arrangement.provider.DummyMLSClient
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.authenticated.client.DeviceTypeDTO
Expand Down Expand Up @@ -1306,7 +1307,7 @@ class MLSConversationRepositoryTest {
member2 to listOf(WIRE_IDENTITY.copy(clientId = CRYPTO_CLIENT_ID.copy("member_2_client_id")))
)
),
mlsConversationRepository.getMembersIdentities(arrangement.mlsContext, TestConversation.ID, listOf(member1, member2, member3))
mlsConversationRepository.getMembersIdentities(arrangement.mlsClient, TestConversation.ID, listOf(member1, member2, member3))
)

coVerify {
Expand Down Expand Up @@ -1396,7 +1397,7 @@ class MLSConversationRepositoryTest {
.withGetUserIdentitiesReturn(mapOf(groupId to listOf(wireIdentity)))
.arrange()
// when
val result = mlsConversationRepository.getMembersIdentities(arrangement.mlsContext, TestConversation.ID, listOf(TestUser.USER_ID))
val result = mlsConversationRepository.getMembersIdentities(arrangement.mlsClient, TestConversation.ID, listOf(TestUser.USER_ID))
// then
result.shouldSucceed() {
it.values.forEach {
Expand Down Expand Up @@ -1480,7 +1481,7 @@ class MLSConversationRepositoryTest {
val checkRevocationList = mock(RevocationListChecker::class)
val certificateRevocationListRepository = mock(CertificateRevocationListRepository::class)
val epochChangesObserver = mock(EpochChangesObserver::class)

val mlsClient = DummyMLSClient(mlsContext)
val epochsFlow = MutableSharedFlow<GroupID>()

val proposalTimersFlow = MutableSharedFlow<ProposalTimer>()
Expand Down
Loading
Loading