diff --git a/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq b/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq index 40f52b18632..0f75d4c9c3e 100644 --- a/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq +++ b/data/persistence/src/commonMain/db_global/com/wire/kalium/persistence/Accounts.sq @@ -57,6 +57,9 @@ SELECT isPersistentWebSocketEnabled FROM Accounts WHERE logout_reason IS NULL AN updatePersistentWebSocketStatus: UPDATE Accounts SET isPersistentWebSocketEnabled = :isPersistentWebSocketEnabled WHERE id = :userId; +updateAllPersistentWebSocketStatus: +UPDATE Accounts SET isPersistentWebSocketEnabled = :enabled WHERE logout_reason IS NULL; + updateSsoId: UPDATE Accounts SET scim_external_id = :scimExternalId, subject = :subject, tenant = :tenant WHERE id = :userId; diff --git a/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt b/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt index f02a2a2058d..44371de022b 100644 --- a/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt +++ b/data/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/daokaliumdb/AccountsDAO.kt @@ -174,6 +174,7 @@ interface AccountsDAO { suspend fun deleteAccount(userIDEntity: UserIDEntity) suspend fun markAccountAsInvalid(userIDEntity: UserIDEntity, logoutReason: LogoutReason) suspend fun updatePersistentWebSocketStatus(userIDEntity: UserIDEntity, isPersistentWebSocketEnabled: Boolean) + suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean suspend fun accountInfo(userIDEntity: UserIDEntity): AccountInfoEntity? fun fullAccountInfo(userIDEntity: UserIDEntity): FullAccountEntity? @@ -304,6 +305,12 @@ internal class AccountsDAOImpl internal constructor( } } + override suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) { + withContext(queriesContext) { + queries.updateAllPersistentWebSocketStatus(enabled) + } + } + override suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean = withContext(queriesContext) { queries.persistentWebSocketStatus(userIDEntity).executeAsOne() } diff --git a/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt b/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt index f3af6f1e3cc..bc191ebe9f4 100644 --- a/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt +++ b/data/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/globalDB/AccountsDAOTest.kt @@ -29,15 +29,9 @@ import com.wire.kalium.persistence.db.GlobalDatabaseBuilder import com.wire.kalium.persistence.model.LogoutReason import com.wire.kalium.persistence.model.ServerConfigEntity import com.wire.kalium.persistence.model.SsoIdEntity -import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.first import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.test.StandardTestDispatcher -import kotlinx.coroutines.test.TestCoroutineScheduler -import kotlinx.coroutines.test.TestDispatcher -import kotlinx.coroutines.test.resetMain import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.test.setMain -import kotlin.test.AfterTest import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals @@ -207,6 +201,79 @@ class AccountsDAOTest : GlobalDBBaseTest() { assertEquals(null, result) } + @Test + fun whenUpdatingPersistentWebSocketStatus_thenStatusIsUpdated() = runTest { + val account = VALID_ACCOUNT + globalDatabaseBuilder.accountsDAO.insertOrReplace(account.info.userIDEntity, account.ssoId, account.managedBy, account.serverConfigId, false) + + // initial status false + val initial = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity) + assertEquals(false, initial) + + // update to true + globalDatabaseBuilder.accountsDAO.updatePersistentWebSocketStatus(account.info.userIDEntity, true) + val updated = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity) + assertEquals(true, updated) + } + + @Test + fun whenSettingAllAccountsPersistentWebSocketEnabled_thenAllStatusesAreUpdated() = runTest { + val a1 = VALID_ACCOUNT + val a2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user2", "domain2"), null)) + val a3 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user3", "domain3"), null)) + + listOf(a1, a2, a3).forEach { + globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false) + } + + globalDatabaseBuilder.accountsDAO.setAllAccountsPersistentWebSocketEnabled(true) + + listOf(a1, a2, a3).forEach { + val status = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(it.info.userIDEntity) + assertEquals(true, status) + } + } + + @Test + fun whenGettingAllValidAccountPersistentWebSocketStatus_thenOnlyValidAccountsIncluded() = runTest { + val valid1 = VALID_ACCOUNT + val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userB", "domainB"), null)) + val invalid = INVALID_ACCOUNT + + // insert accounts with different initial statuses + globalDatabaseBuilder.accountsDAO.insertOrReplace(valid1.info.userIDEntity, valid1.ssoId, valid1.managedBy, valid1.serverConfigId, true) + globalDatabaseBuilder.accountsDAO.insertOrReplace(valid2.info.userIDEntity, valid2.ssoId, valid2.managedBy, valid2.serverConfigId, false) + globalDatabaseBuilder.accountsDAO.insertOrReplace(invalid.info.userIDEntity, invalid.ssoId, invalid.managedBy, invalid.serverConfigId, true) + globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!) + + val list = globalDatabaseBuilder.accountsDAO.getAllValidAccountPersistentWebSocketStatus().first() + // Should contain only the two valid accounts in any order + val ids = list.map { it.userIDEntity }.toSet() + assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), ids) + val map = list.associateBy({ it.userIDEntity }, { it.isPersistentWebSocketEnabled }) + assertEquals(true, map[valid1.info.userIDEntity]) + assertEquals(false, map[valid2.info.userIDEntity]) + } + + @Test + fun whenRequestingValidAccountWithServerConfigId_thenReturnMapForValidAccounts() = runTest { + val valid1 = VALID_ACCOUNT + val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userC", "domainC"), null)) + val invalid = INVALID_ACCOUNT + + listOf(valid1, valid2, invalid).forEach { + globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false) + } + globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!) + + val map = globalDatabaseBuilder.accountsDAO.validAccountWithServerConfigId() + // only valid1 and valid2 should be present + assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), map.keys) + map.values.forEach { serverConfig -> + assertEquals(SERVER_CONFIG, serverConfig) + } + } + private companion object { val VALID_ACCOUNT = FullAccountEntity( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt index 16d741992db..48371f9d405 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/GlobalKaliumScope.kt @@ -63,6 +63,8 @@ import com.wire.kalium.logic.feature.user.ObserveValidAccountsUseCase import com.wire.kalium.logic.feature.user.ObserveValidAccountsUseCaseImpl import com.wire.kalium.logic.feature.user.webSocketStatus.ObservePersistentWebSocketConnectionStatusUseCase import com.wire.kalium.logic.feature.user.webSocketStatus.ObservePersistentWebSocketConnectionStatusUseCaseImpl +import com.wire.kalium.logic.feature.user.webSocketStatus.SetPersistentWebSocketForAllUsersUseCase +import com.wire.kalium.logic.feature.user.webSocketStatus.SetPersistentWebSocketForAllUsersUseCaseImpl import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.sync.GlobalWorkScheduler import com.wire.kalium.logic.sync.WorkSchedulerProvider @@ -121,6 +123,9 @@ public class GlobalKaliumScope internal constructor( public val observePersistentWebSocketConnectionStatus: ObservePersistentWebSocketConnectionStatusUseCase get() = ObservePersistentWebSocketConnectionStatusUseCaseImpl(sessionRepository) + public val setAllPersistentWebSocketEnabled: SetPersistentWebSocketForAllUsersUseCase + get() = SetPersistentWebSocketForAllUsersUseCaseImpl(sessionRepository) + private val notificationTokenRepository: NotificationTokenRepository get() = NotificationTokenDataSource(globalPreferences.tokenStorage) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt index 167578a4dea..b76fa27676d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/session/SessionRepository.kt @@ -77,6 +77,7 @@ internal interface SessionRepository { suspend fun deleteSession(userId: UserId): Either suspend fun ssoId(userId: UserId): Either suspend fun updatePersistentWebSocketStatus(userId: UserId, isPersistentWebSocketEnabled: Boolean): Either + suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either suspend fun updateSsoIdAndScimInfo(userId: UserId, ssoId: SsoId?, managedBy: ManagedByDTO?): Either suspend fun isFederated(userId: UserId): Either suspend fun getAllValidAccountPersistentWebSocketStatus(): Either>> @@ -198,6 +199,9 @@ internal class SessionDataSource internal constructor( accountsDAO.updatePersistentWebSocketStatus(userId.toDao(), isPersistentWebSocketEnabled) } + override suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either = + wrapStorageRequest { accountsDAO.setAllAccountsPersistentWebSocketEnabled(enabled) } + override suspend fun updateSsoIdAndScimInfo( userId: UserId, ssoId: SsoId?, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/webSocketStatus/SetPersistentWebSocketForAllUsersUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/webSocketStatus/SetPersistentWebSocketForAllUsersUseCase.kt new file mode 100644 index 00000000000..17e4516f34b --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/webSocketStatus/SetPersistentWebSocketForAllUsersUseCase.kt @@ -0,0 +1,49 @@ +/* + * Wire + * Copyright (C) 2025 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.kalium.logic.feature.user.webSocketStatus + +import com.wire.kalium.common.error.CoreFailure +import com.wire.kalium.common.functional.fold +import com.wire.kalium.logic.data.session.SessionRepository + +/** + * This use case is responsible for setting the persistent web socket connection status for all users. + */ +public interface SetPersistentWebSocketForAllUsersUseCase { + /** + * @param enabled true if the persistent web socket connection should be enabled for all users, false otherwise + */ + public suspend operator fun invoke(enabled: Boolean): SetAllPersistentWebSocketEnabledResult +} + +public sealed class SetAllPersistentWebSocketEnabledResult { + public data object Success : SetAllPersistentWebSocketEnabledResult() + public data class Failure(val failure: CoreFailure) : SetAllPersistentWebSocketEnabledResult() +} + +internal class SetPersistentWebSocketForAllUsersUseCaseImpl( + private val sessionRepository: SessionRepository +) : SetPersistentWebSocketForAllUsersUseCase { + override suspend operator fun invoke(enabled: Boolean): SetAllPersistentWebSocketEnabledResult = + sessionRepository.setAllPersistentWebSocketEnabled(enabled).fold({ + SetAllPersistentWebSocketEnabledResult.Failure(it) + }, { + SetAllPersistentWebSocketEnabledResult.Success + }) +}