Skip to content

Commit

Permalink
WIP: SnapshotManager
Browse files Browse the repository at this point in the history
  • Loading branch information
grote committed Sep 13, 2024
1 parent c8dde29 commit fa1579e
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ import com.stevesoltys.seedvault.header.VERSION
import com.stevesoltys.seedvault.proto.Snapshot
import com.stevesoltys.seedvault.transport.restore.Loader
import io.github.oshai.kotlinlogging.KotlinLogging
import okio.Buffer
import okio.buffer
import okio.sink
import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.toHexString
import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException

/**
* Manages interactions with snapshots, such as loading, saving and removing them.
* Also keeps a reference to the [latestSnapshot] that holds important re-usable data.
*/
internal class SnapshotManager(
private val snapshotFolder: File,
private val crypto: Crypto,
private val loader: Loader,
private val backendManager: BackendManager,
Expand All @@ -32,37 +38,73 @@ internal class SnapshotManager(
var latestSnapshot: Snapshot? = null
private set

/**
* Call this before starting a backup run with the [handles] of snapshots
* currently available on the backend.
*/
suspend fun onSnapshotsLoaded(handles: List<AppBackupFileType.Snapshot>): List<Snapshot> {
return handles.map { snapshotHandle ->
// TODO set up local snapshot cache, so we don't need to download those all the time
// TODO is it a fatal error when one snapshot is corrupted or couldn't get loaded?
val snapshot = loader.loadFile(snapshotHandle).use { inputStream ->
Snapshot.parseFrom(inputStream)
}
val snapshot = loadSnapshot(snapshotHandle)
// update latest snapshot if this one is more recent
if (snapshot.token > (latestSnapshot?.token ?: 0)) latestSnapshot = snapshot
snapshot
}
}

/**
* Saves the given [snapshot] to the backend and local cache.
*
* @throws IOException or others if saving fails.
*/
@Throws(IOException::class)
suspend fun saveSnapshot(snapshot: Snapshot) {
val buffer = Buffer()
val bufferStream = buffer.outputStream()
bufferStream.write(VERSION.toInt())
crypto.newEncryptingStream(bufferStream, crypto.getAdForVersion()).use { cryptoStream ->
val byteStream = ByteArrayOutputStream()
byteStream.write(VERSION.toInt())
crypto.newEncryptingStream(byteStream, crypto.getAdForVersion()).use { cryptoStream ->
ZstdOutputStream(cryptoStream).use { zstdOutputStream ->
snapshot.writeTo(zstdOutputStream)
}
}
val sha256ByteString = buffer.sha256()
val handle = AppBackupFileType.Snapshot(crypto.repoId, sha256ByteString.hex())
// TODO exception handling
backendManager.backend.save(handle).use { outputStream ->
outputStream.sink().buffer().apply {
writeAll(buffer)
flush() // needs flushing
val bytes = byteStream.toByteArray()
val sha256 = crypto.sha256(bytes).toHexString()
val snapshotHandle = AppBackupFileType.Snapshot(crypto.repoId, sha256)
backendManager.backend.save(snapshotHandle).use { outputStream ->
outputStream.write(bytes)
}
// save to local cache while at it
try {
File(snapshotFolder, snapshotHandle.name).outputStream().use { outputStream ->
outputStream.write(bytes)
}
} catch (e: Exception) { // we'll let this one pass
log.error(e) { "Error saving snapshot ${snapshotHandle.hash} to cache" }
}
}

/**
* Removes the snapshot referenced by the given [snapshotHandle] from the backend
* and local cache.
*/
suspend fun removeSnapshot(snapshotHandle: AppBackupFileType.Snapshot) {
backendManager.backend.remove(snapshotHandle)
// remove from cache as well
File(snapshotFolder, snapshotHandle.name).delete()
}

/**
* Loads and parses the snapshot referenced by the given [snapshotHandle].
* If a locally cached version exists, the backend will not be hit.
*/
private suspend fun loadSnapshot(snapshotHandle: AppBackupFileType.Snapshot): Snapshot {
val file = File(snapshotFolder, snapshotHandle.name)
val inputStream = if (file.isFile) {
loader.loadFile(file, snapshotHandle.hash)
} else {
loader.loadFile(snapshotHandle, file)
}
return inputStream.use { Snapshot.parseFrom(it) }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,31 @@ internal class AppBackupManager(
blobCache.populateCache(blobInfos, snapshots)
}

suspend fun afterBackupFinished(success: Boolean) {
suspend fun afterBackupFinished(success: Boolean): Boolean {
log.info { "After backup finished. Success: $success" }
MemoryLogger.log()
// free up memory by clearing blobs cache
blobCache.clear()
var result = false
try {
if (success) {
val snapshot =
snapshotCreator?.finalizeSnapshot() ?: error("Had no snapshotCreator")
keepTrying {
keepTrying { // saving this is so important, we even keep trying
snapshotManager.saveSnapshot(snapshot)
}
settingsManager.token = snapshot.token
// after snapshot was written, we can clear local cache as its info is in snapshot
blobCache.clearLocalCache()
}
result = true
} catch (e: Exception) {
log.error(e) { "Error finishing backup" }
} finally {
snapshotCreator = null
}
MemoryLogger.log()
return result
}

private suspend fun keepTrying(n: Int = 3, block: suspend () -> Unit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ package com.stevesoltys.seedvault.transport.backup
import com.stevesoltys.seedvault.transport.SnapshotManager
import org.koin.android.ext.koin.androidContext
import org.koin.dsl.module
import java.io.File

val backupModule = module {
single { BackupInitializer(get()) }
single { BackupReceiver(get(), get(), get()) }
single { BlobCache(androidContext()) }
single { BlobCreator(get(), get()) }
single { SnapshotManager(get(), get(), get()) }
single {
val snapshotFolder = File(androidContext().filesDir, "snapshots")
SnapshotManager(snapshotFolder, get(), get(), get())
}
single { SnapshotCreatorFactory(androidContext(), get(), get(), get()) }
single { InputFactory() }
single {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@

package com.stevesoltys.seedvault.transport.restore

import com.android.internal.R.attr.handle
import com.github.luben.zstd.ZstdInputStream
import com.stevesoltys.seedvault.backend.BackendManager
import com.stevesoltys.seedvault.crypto.Crypto
import com.stevesoltys.seedvault.header.UnsupportedVersionException
import com.stevesoltys.seedvault.header.VERSION
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.runBlocking
import org.calyxos.seedvault.core.backends.AppBackupFileType
import org.calyxos.seedvault.core.toHexString
import java.io.ByteArrayInputStream
import java.io.File
import java.io.InputStream
import java.io.SequenceInputStream
import java.security.GeneralSecurityException
Expand All @@ -24,48 +27,78 @@ internal class Loader(
private val backendManager: BackendManager,
) {

private val log = KotlinLogging.logger {}

/**
* Downloads the given [fileHandle], decrypts and decompresses its content
* and returns the content as a decrypted and decompressed stream.
*
* Attention: The responsibility with closing the returned stream lies with the caller.
*
* @param cacheFile if non-null, the ciphertext of the loaded file will be cached there
* for later loading with [loadFile].
*/
suspend fun loadFile(fileHandle: AppBackupFileType, cacheFile: File? = null): InputStream {
val expectedHash = when (fileHandle) {
is AppBackupFileType.Snapshot -> fileHandle.hash
is AppBackupFileType.Blob -> fileHandle.name
}
return loadFromStream(backendManager.backend.load(fileHandle), expectedHash, cacheFile)
}

/**
* The responsibility with closing the returned stream lies with the caller.
*/
suspend fun loadFile(handle: AppBackupFileType): InputStream {
fun loadFile(file: File, expectedHash: String): InputStream {
return loadFromStream(file.inputStream(), expectedHash)
}

suspend fun loadFiles(handles: List<AppBackupFileType>): InputStream {
val enumeration: Enumeration<InputStream> = object : Enumeration<InputStream> {
val iterator = handles.iterator()

override fun hasMoreElements(): Boolean {
return iterator.hasNext()
}

override fun nextElement(): InputStream {
return runBlocking { loadFile(iterator.next()) }
}
}
return SequenceInputStream(enumeration)
}

private fun loadFromStream(
inputStream: InputStream,
expectedHash: String,
cacheFile: File? = null,
): InputStream {
// We load the entire ciphertext into memory,
// so we can check the SHA-256 hash before decrypting and parsing the data.
val cipherText = backendManager.backend.load(handle).use { inputStream ->
inputStream.readAllBytes()
}
val cipherText = inputStream.use { it.readAllBytes() }
// check SHA-256 hash first thing
val sha256 = crypto.sha256(cipherText).toHexString()
val expectedHash = when (handle) {
is AppBackupFileType.Snapshot -> handle.hash
is AppBackupFileType.Blob -> handle.name
}
if (sha256 != expectedHash) {
throw GeneralSecurityException("File had wrong SHA-256 hash: $handle")
}
// check that we can handle the version of that snapshot
val version = cipherText[0]
if (version <= 1) throw GeneralSecurityException("Unexpected version: $version")
if (version > VERSION) throw UnsupportedVersionException(version)
// cache ciperText in cacheFile, if existing
try {
cacheFile?.outputStream()?.use { outputStream ->
outputStream.write(cipherText)
}
} catch (e: Exception) {
log.error(e) { "Error writing cache file $cacheFile" }
}
// get associated data for version, used for authenticated decryption
val ad = crypto.getAdForVersion(version)
// skip first version byte when creating cipherText stream
val inputStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1)
val byteStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1)
// decrypt and decompress cipherText stream and parse snapshot
return ZstdInputStream(crypto.newDecryptingStream(inputStream, ad))
return ZstdInputStream(crypto.newDecryptingStream(byteStream, ad))
}

suspend fun loadFiles(handles: List<AppBackupFileType>): InputStream {
val enumeration: Enumeration<InputStream> = object : Enumeration<InputStream> {
val iterator = handles.iterator()

override fun hasMoreElements(): Boolean {
return iterator.hasNext()
}

override fun nextElement(): InputStream {
return runBlocking { loadFile(iterator.next()) }
}
}
return SequenceInputStream(enumeration)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import android.app.backup.IBackupObserver
import android.content.Context
import android.content.pm.ApplicationInfo.FLAG_SYSTEM
import android.content.pm.PackageManager.NameNotFoundException
import android.os.Looper
import android.util.Log
import android.util.Log.INFO
import android.util.Log.isLoggable
Expand Down Expand Up @@ -136,19 +137,18 @@ internal class NotificationBackupObserver(
if (isLoggable(TAG, INFO)) {
Log.i(TAG, "Backup finished $numPackages/$requestedPackages. Status: $status")
}
val success = status == 0
var success = status == 0
val size = if (success) metadataManager.getPackagesBackupSize() else 0L
val total = try {
packageService.allUserPackages.size
} catch (e: Exception) {
Log.e(TAG, "Error getting number of all user packages: ", e)
requestedPackages
}
// TODO handle exceptions
runBlocking {
// TODO check if UI thread
Log.d("TAG", "Finalizing backup...")
appBackupManager.afterBackupFinished(success)
check(!Looper.getMainLooper().isCurrentThread)
Log.d(TAG, "Finalizing backup...")
success = appBackupManager.afterBackupFinished(success)
}
nm.onBackupFinished(success, numPackagesToReport, total, size)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ import org.calyxos.seedvault.core.toHexString
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.io.TempDir
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.File
import java.io.InputStream
import java.io.OutputStream
import java.nio.file.Path
import java.security.MessageDigest
import kotlin.random.Random

Expand All @@ -32,15 +35,28 @@ internal class SnapshotManagerTest : TransportTest() {
private val backend: Backend = mockk()

private val loader = Loader(crypto, backendManager) // need a real loader
private val snapshotManager = SnapshotManager(crypto, loader, backendManager)

private val ad = Random.nextBytes(1)
private val passThroughOutputStream = slot<OutputStream>()
private val passThroughInputStream = slot<InputStream>()
private val snapshotHandle = slot<AppBackupFileType.Snapshot>()

// @Test
// fun `test onSnapshotsLoaded sets latestSnapshot`(@TempDir tmpDir: Path) = runBlocking {
// val snapshotManager = getSnapshotManager(File(tmpDir.toString()))
//
// val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1)
// val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2)
// snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2))
// Unit
// }

@Test
fun `test saving and loading`() = runBlocking {
fun `test saving and loading`(@TempDir tmpDir: Path) = runBlocking {
val snapshotManager = getSnapshotManager(File(tmpDir.toString()))

val messageDigest = MessageDigest.getInstance("SHA-256")
val bytes = slot<ByteArray>()
val outputStream = ByteArrayOutputStream()

every { crypto.getAdForVersion() } returns ad
Expand All @@ -49,12 +65,16 @@ internal class SnapshotManagerTest : TransportTest() {
}
every { crypto.repoId } returns repoId
every { backendManager.backend } returns backend
every { crypto.sha256(capture(bytes)) } answers {
messageDigest.digest(bytes.captured)
}
coEvery { backend.save(capture(snapshotHandle)) } returns outputStream

snapshotManager.saveSnapshot(snapshot)

println(snapshotHandle.captured)

// check that file content hash matches snapshot hash
val messageDigest = MessageDigest.getInstance("SHA-256")
assertEquals(
messageDigest.digest(outputStream.toByteArray()).toHexString(),
snapshotHandle.captured.hash,
Expand All @@ -75,4 +95,8 @@ internal class SnapshotManagerTest : TransportTest() {
assertEquals(snapshot, snapshots[0])
}
}

private fun getSnapshotManager(tmpFolder: File): SnapshotManager {
return SnapshotManager(tmpFolder, crypto, loader, backendManager)
}
}

0 comments on commit fa1579e

Please sign in to comment.