Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ SELECT isPersistentWebSocketEnabled FROM Accounts WHERE logout_reason IS NULL AN
updatePersistentWebSocketStatus:
UPDATE Accounts SET isPersistentWebSocketEnabled = :isPersistentWebSocketEnabled WHERE id = :userId;

updateAllPersistentWebSocketStatus:
UPDATE Accounts SET isPersistentWebSocketEnabled = :enabled WHERE logout_reason IS NULL;

updateSsoId:
UPDATE Accounts SET scim_external_id = :scimExternalId, subject = :subject, tenant = :tenant WHERE id = :userId;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ interface AccountsDAO {
suspend fun deleteAccount(userIDEntity: UserIDEntity)
suspend fun markAccountAsInvalid(userIDEntity: UserIDEntity, logoutReason: LogoutReason)
suspend fun updatePersistentWebSocketStatus(userIDEntity: UserIDEntity, isPersistentWebSocketEnabled: Boolean)
suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean)
suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean
suspend fun accountInfo(userIDEntity: UserIDEntity): AccountInfoEntity?
fun fullAccountInfo(userIDEntity: UserIDEntity): FullAccountEntity?
Expand Down Expand Up @@ -304,6 +305,12 @@ internal class AccountsDAOImpl internal constructor(
}
}

override suspend fun setAllAccountsPersistentWebSocketEnabled(enabled: Boolean) {
withContext(queriesContext) {
queries.updateAllPersistentWebSocketStatus(enabled)
}
}

override suspend fun persistentWebSocketStatus(userIDEntity: UserIDEntity): Boolean = withContext(queriesContext) {
queries.persistentWebSocketStatus(userIDEntity).executeAsOne()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,9 @@ import com.wire.kalium.persistence.db.GlobalDatabaseBuilder
import com.wire.kalium.persistence.model.LogoutReason
import com.wire.kalium.persistence.model.ServerConfigEntity
import com.wire.kalium.persistence.model.SsoIdEntity
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.TestDispatcher
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
import kotlin.test.AfterTest
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -207,6 +201,79 @@ class AccountsDAOTest : GlobalDBBaseTest() {
assertEquals(null, result)
}

@Test
fun whenUpdatingPersistentWebSocketStatus_thenStatusIsUpdated() = runTest {
val account = VALID_ACCOUNT
globalDatabaseBuilder.accountsDAO.insertOrReplace(account.info.userIDEntity, account.ssoId, account.managedBy, account.serverConfigId, false)

// initial status false
val initial = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity)
assertEquals(false, initial)

// update to true
globalDatabaseBuilder.accountsDAO.updatePersistentWebSocketStatus(account.info.userIDEntity, true)
val updated = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(account.info.userIDEntity)
assertEquals(true, updated)
}

@Test
fun whenSettingAllAccountsPersistentWebSocketEnabled_thenAllStatusesAreUpdated() = runTest {
val a1 = VALID_ACCOUNT
val a2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user2", "domain2"), null))
val a3 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("user3", "domain3"), null))

listOf(a1, a2, a3).forEach {
globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false)
}

globalDatabaseBuilder.accountsDAO.setAllAccountsPersistentWebSocketEnabled(true)

listOf(a1, a2, a3).forEach {
val status = globalDatabaseBuilder.accountsDAO.persistentWebSocketStatus(it.info.userIDEntity)
assertEquals(true, status)
}
}

@Test
fun whenGettingAllValidAccountPersistentWebSocketStatus_thenOnlyValidAccountsIncluded() = runTest {
val valid1 = VALID_ACCOUNT
val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userB", "domainB"), null))
val invalid = INVALID_ACCOUNT

// insert accounts with different initial statuses
globalDatabaseBuilder.accountsDAO.insertOrReplace(valid1.info.userIDEntity, valid1.ssoId, valid1.managedBy, valid1.serverConfigId, true)
globalDatabaseBuilder.accountsDAO.insertOrReplace(valid2.info.userIDEntity, valid2.ssoId, valid2.managedBy, valid2.serverConfigId, false)
globalDatabaseBuilder.accountsDAO.insertOrReplace(invalid.info.userIDEntity, invalid.ssoId, invalid.managedBy, invalid.serverConfigId, true)
globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!)

val list = globalDatabaseBuilder.accountsDAO.getAllValidAccountPersistentWebSocketStatus().first()
// Should contain only the two valid accounts in any order
val ids = list.map { it.userIDEntity }.toSet()
assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), ids)
val map = list.associateBy({ it.userIDEntity }, { it.isPersistentWebSocketEnabled })
assertEquals(true, map[valid1.info.userIDEntity])
assertEquals(false, map[valid2.info.userIDEntity])
}

@Test
fun whenRequestingValidAccountWithServerConfigId_thenReturnMapForValidAccounts() = runTest {
val valid1 = VALID_ACCOUNT
val valid2 = VALID_ACCOUNT.copy(info = AccountInfoEntity(UserIDEntity("userC", "domainC"), null))
val invalid = INVALID_ACCOUNT

listOf(valid1, valid2, invalid).forEach {
globalDatabaseBuilder.accountsDAO.insertOrReplace(it.info.userIDEntity, it.ssoId, it.managedBy, it.serverConfigId, false)
}
globalDatabaseBuilder.accountsDAO.markAccountAsInvalid(invalid.info.userIDEntity, invalid.info.logoutReason!!)

val map = globalDatabaseBuilder.accountsDAO.validAccountWithServerConfigId()
// only valid1 and valid2 should be present
assertEquals(setOf(valid1.info.userIDEntity, valid2.info.userIDEntity), map.keys)
map.values.forEach { serverConfig ->
assertEquals(SERVER_CONFIG, serverConfig)
}
}

private companion object {

val VALID_ACCOUNT = FullAccountEntity(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ import com.wire.kalium.logic.feature.user.ObserveValidAccountsUseCase
import com.wire.kalium.logic.feature.user.ObserveValidAccountsUseCaseImpl
import com.wire.kalium.logic.feature.user.webSocketStatus.ObservePersistentWebSocketConnectionStatusUseCase
import com.wire.kalium.logic.feature.user.webSocketStatus.ObservePersistentWebSocketConnectionStatusUseCaseImpl
import com.wire.kalium.logic.feature.user.webSocketStatus.SetPersistentWebSocketForAllUsersUseCase
import com.wire.kalium.logic.feature.user.webSocketStatus.SetPersistentWebSocketForAllUsersUseCaseImpl
import com.wire.kalium.logic.featureFlags.KaliumConfigs
import com.wire.kalium.logic.sync.GlobalWorkScheduler
import com.wire.kalium.logic.sync.WorkSchedulerProvider
Expand Down Expand Up @@ -121,6 +123,9 @@ public class GlobalKaliumScope internal constructor(
public val observePersistentWebSocketConnectionStatus: ObservePersistentWebSocketConnectionStatusUseCase
get() = ObservePersistentWebSocketConnectionStatusUseCaseImpl(sessionRepository)

public val setAllPersistentWebSocketEnabled: SetPersistentWebSocketForAllUsersUseCase
get() = SetPersistentWebSocketForAllUsersUseCaseImpl(sessionRepository)

private val notificationTokenRepository: NotificationTokenRepository
get() = NotificationTokenDataSource(globalPreferences.tokenStorage)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ internal interface SessionRepository {
suspend fun deleteSession(userId: UserId): Either<StorageFailure, Unit>
suspend fun ssoId(userId: UserId): Either<StorageFailure, SsoIdEntity?>
suspend fun updatePersistentWebSocketStatus(userId: UserId, isPersistentWebSocketEnabled: Boolean): Either<StorageFailure, Unit>
suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit>
suspend fun updateSsoIdAndScimInfo(userId: UserId, ssoId: SsoId?, managedBy: ManagedByDTO?): Either<StorageFailure, Unit>
suspend fun isFederated(userId: UserId): Either<StorageFailure, Boolean>
suspend fun getAllValidAccountPersistentWebSocketStatus(): Either<StorageFailure, Flow<List<PersistentWebSocketStatus>>>
Expand Down Expand Up @@ -198,6 +199,9 @@ internal class SessionDataSource internal constructor(
accountsDAO.updatePersistentWebSocketStatus(userId.toDao(), isPersistentWebSocketEnabled)
}

override suspend fun setAllPersistentWebSocketEnabled(enabled: Boolean): Either<StorageFailure, Unit> =
wrapStorageRequest { accountsDAO.setAllAccountsPersistentWebSocketEnabled(enabled) }

override suspend fun updateSsoIdAndScimInfo(
userId: UserId,
ssoId: SsoId?,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.logic.feature.user.webSocketStatus

import com.wire.kalium.common.error.CoreFailure
import com.wire.kalium.common.functional.fold
import com.wire.kalium.logic.data.session.SessionRepository

/**
* This use case is responsible for setting the persistent web socket connection status for all users.
*/
public interface SetPersistentWebSocketForAllUsersUseCase {
/**
* @param enabled true if the persistent web socket connection should be enabled for all users, false otherwise
*/
public suspend operator fun invoke(enabled: Boolean): SetAllPersistentWebSocketEnabledResult
}

public sealed class SetAllPersistentWebSocketEnabledResult {
public data object Success : SetAllPersistentWebSocketEnabledResult()
public data class Failure(val failure: CoreFailure) : SetAllPersistentWebSocketEnabledResult()
}

internal class SetPersistentWebSocketForAllUsersUseCaseImpl(
private val sessionRepository: SessionRepository
) : SetPersistentWebSocketForAllUsersUseCase {
override suspend operator fun invoke(enabled: Boolean): SetAllPersistentWebSocketEnabledResult =
sessionRepository.setAllPersistentWebSocketEnabled(enabled).fold({
SetAllPersistentWebSocketEnabledResult.Failure(it)
}, {
SetAllPersistentWebSocketEnabledResult.Success
})
}
Loading