diff --git a/app/src/main/java/tech/relaycorp/courier/domain/StoreMessage.kt b/app/src/main/java/tech/relaycorp/courier/domain/StoreMessage.kt index 1385563c..f4d770de 100644 --- a/app/src/main/java/tech/relaycorp/courier/domain/StoreMessage.kt +++ b/app/src/main/java/tech/relaycorp/courier/domain/StoreMessage.kt @@ -11,6 +11,7 @@ import tech.relaycorp.courier.data.model.StoredMessage import tech.relaycorp.relaynet.cogrpc.readBytesAndClose import tech.relaycorp.relaynet.messages.Cargo import tech.relaycorp.relaynet.messages.CargoCollectionAuthorization +import tech.relaycorp.relaynet.messages.InvalidMessageException import tech.relaycorp.relaynet.ramf.RAMFException import tech.relaycorp.relaynet.ramf.RAMFMessage import java.io.InputStream @@ -38,10 +39,24 @@ class StoreMessage } try { - cargo.validate(null) + cargo.validate( + when (recipientType) { + GatewayType.Internet -> null + GatewayType.Private -> + cargo.recipientCertificate + ?.let { setOf(it) } + ?: run { + logger.warning("Invalid cargo received with missing recipient certificate") + return Result.Error.Invalid + } + }, + ) } catch (exc: RAMFException) { logger.warning("Invalid cargo received: ${exc.message}") return Result.Error.Invalid + } catch (exc: InvalidMessageException) { + logger.warning("Invalid cargo received: ${exc.message}") + return Result.Error.Invalid } return storeMessage(MessageType.Cargo, cargo, cargoBytes, recipientType) diff --git a/app/src/test/java/tech/relaycorp/courier/domain/StoreMessageTest.kt b/app/src/test/java/tech/relaycorp/courier/domain/StoreMessageTest.kt index 41863f38..3b4227ca 100644 --- a/app/src/test/java/tech/relaycorp/courier/domain/StoreMessageTest.kt +++ b/app/src/test/java/tech/relaycorp/courier/domain/StoreMessageTest.kt @@ -11,7 +11,6 @@ import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import tech.relaycorp.courier.data.database.StoredMessageDao @@ -20,6 +19,7 @@ import tech.relaycorp.courier.data.model.GatewayType import tech.relaycorp.courier.data.model.StorageSize import tech.relaycorp.courier.data.model.StorageUsage import tech.relaycorp.courier.data.model.StoredMessage +import tech.relaycorp.relaynet.issueDeliveryAuthorization import tech.relaycorp.relaynet.messages.Cargo import tech.relaycorp.relaynet.messages.CargoCollectionAuthorization import tech.relaycorp.relaynet.messages.Recipient @@ -145,7 +145,10 @@ class StoreMessageTest { invalidCargo.serialize(KeyPairSet.PRIVATE_GW.private) val result = - subject.storeCargo(invalidCargoSerialized.inputStream(), GatewayType.Internet) + subject.storeCargo( + invalidCargoSerialized.inputStream(), + GatewayType.Internet, + ) assertEquals(StoreMessage.Result.Error.Invalid, result) verify(diskRepository, never()).writeMessage(any()) @@ -163,7 +166,8 @@ class StoreMessageTest { ) val cargoSerialized = cargo.serialize(KeyPairSet.PRIVATE_GW.private) - val result = subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Internet) + val result = + subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Internet) assertTrue(result is StoreMessage.Result.Success) verify(diskRepository).writeMessage(any()) @@ -185,7 +189,8 @@ class StoreMessageTest { ) val cargoSerialized = cargo.serialize(KeyPairSet.PRIVATE_GW.private) - val result = subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Internet) + val result = + subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Internet) assertTrue(result is StoreMessage.Result.Success) verify(diskRepository).writeMessage(any()) @@ -195,21 +200,30 @@ class StoreMessageTest { @Nested inner class BoundForPrivateGateway { - private val recipient = Recipient(KeyPairSet.PRIVATE_GW.public.nodeId) + private val recipient = Recipient(CDACertPath.PRIVATE_GW.subjectPublicKey.nodeId) + private val senderCertificate = + issueDeliveryAuthorization( + KeyPairSet.INTERNET_GW.public, + KeyPairSet.PRIVATE_GW.private, + ZonedDateTime.now().plusHours(1), + CDACertPath.PRIVATE_GW, + validityStartDate = ZonedDateTime.now().minusMinutes(1), + ) @Test - @Disabled // See https://github.com/relaycorp/relaynet-courier-android/issues/255 fun `Unauthorized cargo should be refused`() = runTest { val cargo = Cargo( recipient.copy(id = "${recipient.id}abc"), "payload".toByteArray(), - CDACertPath.INTERNET_GW, + senderCertificate, + senderCertificateChain = setOf(CDACertPath.PRIVATE_GW), ) val cargoSerialized = cargo.serialize(KeyPairSet.INTERNET_GW.private) - val result = subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Private) + val result = + subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Private) assertEquals(StoreMessage.Result.Error.Invalid, result) verify(diskRepository, never()).writeMessage(any()) @@ -223,11 +237,13 @@ class StoreMessageTest { Cargo( recipient, "payload".toByteArray(), - CDACertPath.INTERNET_GW, + senderCertificate, + senderCertificateChain = setOf(CDACertPath.PRIVATE_GW), ) val cargoSerialized = cargo.serialize(KeyPairSet.INTERNET_GW.private) - val result = subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Private) + val result = + subject.storeCargo(cargoSerialized.inputStream(), GatewayType.Private) assertTrue(result is StoreMessage.Result.Success) verify(diskRepository).writeMessage(any())