Skip to content

Commit

Permalink
Transit Batch endpoints: short-cut on empty list of inputs. (#66)
Browse files Browse the repository at this point in the history
At present, running with an empty list of inputs will raise an
error because there is no Right answer. To correct this, and to 
avoid an unneeded call, and to keep stronger constraints, 
we adapt the batch endpoints to require a NonEmptyList
  • Loading branch information
diesalbla authored Dec 6, 2019
1 parent a3e0360 commit 7a419af
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 65 deletions.
31 changes: 16 additions & 15 deletions core/src/main/scala/com/banno/vault/transit/Transit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.banno.vault.transit

import cats.syntax.all._
import cats.data.NonEmptyList
import cats.effect.Sync
import org.http4s._
import org.http4s.Method.{GET, POST}
Expand Down Expand Up @@ -64,8 +65,8 @@ object Transit {
*/
def encryptBatch[F[_]: Sync]
(client: Client[F], vaultUri: Uri, token: String, key: KeyName)
(plaintexts: List[PlainText])
: F[List[TransitError.Or[CipherText]]] =
(plaintexts: NonEmptyList[PlainText])
: F[NonEmptyList[TransitError.Or[CipherText]]] =
new TransitClient[F](client, vaultUri, token, key).encryptBatch(plaintexts)

/** Function to decrypt data.
Expand Down Expand Up @@ -98,8 +99,8 @@ object Transit {
*/
def decryptBatch[F[_]: Sync]
(client: Client[F], vaultUri: Uri, token: String, key: KeyName)
(cipherTexts: List[CipherText])
: F[List[TransitError.Or[PlainText]]] =
(cipherTexts: NonEmptyList[CipherText])
: F[NonEmptyList[TransitError.Or[PlainText]]] =
new TransitClient[F](client, vaultUri, token, key).decryptBatch(cipherTexts)

/** Function to decrypt a batch of data, where each ciphertext is accompanied by its context information.
Expand All @@ -112,8 +113,8 @@ object Transit {
*/
def decryptBatchInContext[F[_]: Sync]
(client: Client[F], vaultUri: Uri, token: String, key: KeyName)
(inputs: List[(CipherText, Context)])
: F[List[TransitError.Or[PlainText]]] =
(inputs: NonEmptyList[(CipherText, Context)])
: F[NonEmptyList[TransitError.Or[PlainText]]] =
new TransitClient[F](client, vaultUri, token, key).decryptInContextBatch(inputs)

}
Expand Down Expand Up @@ -187,7 +188,7 @@ final class TransitClient[F[_]](client: Client[F], vaultUri: Uri, token: String,
*
* https://www.vaultproject.io/api/secret/transit/index.html#batch_input
*/
def encryptBatch(plaintexts: List[PlainText]): F[List[TransitError.Or[CipherText]]] = {
def encryptBatch(plaintexts: NonEmptyList[PlainText]): F[NonEmptyList[TransitError.Or[CipherText]]] = {
val payload = EncryptBatchRequest(plaintexts.map(EncryptRequest(_, None)))
encryptBatchAux(payload, "EncryptBatch without context")
}
Expand All @@ -196,12 +197,12 @@ final class TransitClient[F[_]](client: Client[F], vaultUri: Uri, token: String,
*
* https://www.vaultproject.io/api/secret/transit/index.html#batch_input
*/
def encryptInContextBatch(inputs: List[(PlainText, Context)]): F[List[TransitError.Or[CipherText]]] = {
val payload = EncryptBatchRequest(inputs.map { case (pt, ctx) => EncryptRequest(pt, Some(ctx)) })
encryptBatchAux(payload, "EncryptBatch with context")
}
def encryptInContextBatch(inputs: NonEmptyList[(PlainText, Context)]): F[NonEmptyList[TransitError.Or[CipherText]]] = {
val payload = EncryptBatchRequest(inputs.map { case (pt, ctx) => EncryptRequest(pt, Some(ctx)) })
encryptBatchAux(payload, "EncryptBatch with context")
}

private def encryptBatchAux(payload: EncryptBatchRequest, op: String): F[List[TransitError.Or[CipherText]]] = {
private def encryptBatchAux(payload: EncryptBatchRequest, op: String): F[NonEmptyList[TransitError.Or[CipherText]]] = {
val request = postOf(encryptUri, payload)
for {
results <- F.handleErrorWith(client.expect[EncryptBatchResponse](request).map(_.batchResults)) {
Expand Down Expand Up @@ -245,7 +246,7 @@ final class TransitClient[F[_]](client: Client[F], vaultUri: Uri, token: String,
*
* https://www.vaultproject.io/api/secret/transit/index.html#batch_input-2
*/
def decryptBatch(inputs: List[CipherText]): F[List[TransitError.Or[PlainText]]] = {
def decryptBatch(inputs: NonEmptyList[CipherText]): F[NonEmptyList[TransitError.Or[PlainText]]] = {
val payload = DecryptBatchRequest(inputs.map((cipht: CipherText) => DecryptRequest(cipht, None)))
decryptBatchAux(payload, "DecryptBatch without context")
}
Expand All @@ -256,12 +257,12 @@ final class TransitClient[F[_]](client: Client[F], vaultUri: Uri, token: String,
*
* https://www.vaultproject.io/api/secret/transit/index.html#batch_input-2
*/
def decryptInContextBatch(inputs: List[(CipherText, Context)]): F[List[TransitError.Or[PlainText]]] = {
def decryptInContextBatch(inputs: NonEmptyList[(CipherText, Context)]): F[NonEmptyList[TransitError.Or[PlainText]]] = {
val payload = DecryptBatchRequest(inputs.map { case (cipht, ctx) => DecryptRequest(cipht, Some(ctx)) } )
decryptBatchAux(payload, "DecryptBatch with context")
}

private def decryptBatchAux(payload: DecryptBatchRequest, op: String): F[List[TransitError.Or[PlainText]]] = {
private def decryptBatchAux(payload: DecryptBatchRequest, op: String): F[NonEmptyList[TransitError.Or[PlainText]]] = {
val request = postOf(decryptUri, payload)
for {
results <- F.handleErrorWith(client.expect[DecryptBatchResponse](request).map(_.batchResults)){
Expand Down
23 changes: 12 additions & 11 deletions core/src/main/scala/com/banno/vault/transit/models.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.banno.vault.transit

import cats.Eq
import cats.data.NonEmptyList
import cats.kernel.instances.all._
import cats.syntax.eq._
import io.circe.{Decoder, Encoder, Json}
Expand Down Expand Up @@ -144,24 +145,24 @@ private[transit] object EncryptResponse {
Decoder.forProduct1("data")((d: EncryptResult) => EncryptResponse(d))
}

private[transit] final case class EncryptBatchRequest(batchInput: List[EncryptRequest])
private[transit] final case class EncryptBatchRequest(batchInput: NonEmptyList[EncryptRequest])
private[transit] object EncryptBatchRequest {
implicit val eqEncryptBatchRequest: Eq[EncryptBatchRequest] =
Eq.by[EncryptBatchRequest, List[EncryptRequest]](_.batchInput)
Eq.by[EncryptBatchRequest, NonEmptyList[EncryptRequest]](_.batchInput)
implicit val encodeEncryptBatchRequest: Encoder[EncryptBatchRequest] =
Encoder.forProduct1("batch_input")(_.batchInput)
implicit val decodeEncryptBatchRequest: Decoder[EncryptBatchRequest] =
Decoder.forProduct1("batch_input")((bi: List[EncryptRequest]) => EncryptBatchRequest(bi))
Decoder.forProduct1("batch_input")((bi: NonEmptyList[EncryptRequest]) => EncryptBatchRequest(bi))
}

private[transit] final case class EncryptBatchResponse(batchResults: List[TransitError.Or[EncryptResult]])
private[transit] final case class EncryptBatchResponse(batchResults: NonEmptyList[TransitError.Or[EncryptResult]])
private[transit] object EncryptBatchResponse {
implicit val eqEncryptBatchResponse: Eq[EncryptBatchResponse] =
Eq.by(_.batchResults)
implicit val encodeEncryptBatchResponse: Encoder[EncryptBatchResponse] =
Encoder.forProduct1("batch_results")(_.batchResults)
implicit val decodeEncryptBatchResponse: Decoder[EncryptBatchResponse] =
Decoder.forProduct1("batch_results")((br: List[TransitError.Or[EncryptResult]]) => EncryptBatchResponse(br))
Decoder.forProduct1("batch_results")((br: NonEmptyList[TransitError.Or[EncryptResult]]) => EncryptBatchResponse(br))
}

private[transit] final case class DecryptRequest(ciphertext: CipherText, context: Option[Context])
Expand Down Expand Up @@ -224,22 +225,22 @@ private[transit] object TransitError {

}

private[transit] final case class DecryptBatchRequest(batchInput: List[DecryptRequest])
private[transit] final case class DecryptBatchRequest(batchInput: NonEmptyList[DecryptRequest])
private[transit] object DecryptBatchRequest {
implicit val eqDecryptBatchRequest: Eq[DecryptBatchRequest] =
Eq.by[DecryptBatchRequest, List[DecryptRequest]](_.batchInput)
Eq.by[DecryptBatchRequest, NonEmptyList[DecryptRequest]](_.batchInput)
implicit val encodeDecryptBatchRequest: Encoder[DecryptBatchRequest] =
Encoder.forProduct1("batch_input")(_.batchInput)
implicit val decodeDecryptBatchRequest: Decoder[DecryptBatchRequest] =
Decoder.forProduct1("batch_input")((bi: List[DecryptRequest]) => DecryptBatchRequest(bi))
Decoder.forProduct1("batch_input")((bi: NonEmptyList[DecryptRequest]) => DecryptBatchRequest(bi))

}
private[transit] final case class DecryptBatchResponse(batchResults: List[TransitError.Or[DecryptResult]])
private[transit] final case class DecryptBatchResponse(batchResults: NonEmptyList[TransitError.Or[DecryptResult]])
private[transit] object DecryptBatchResponse {
implicit val eqDecryptBatchResponse: Eq[DecryptBatchResponse] =
Eq.by[DecryptBatchResponse, List[TransitError.Or[DecryptResult]]](_.batchResults)
Eq.by[DecryptBatchResponse, NonEmptyList[TransitError.Or[DecryptResult]]](_.batchResults)
implicit val encodeDecryptBatchResponse: Encoder[DecryptBatchResponse] =
Encoder.forProduct1("batch_results")(_.batchResults)
implicit val decodeDecryptBatchResponse: Decoder[DecryptBatchResponse] =
Decoder.forProduct1("batch_results")((br: List[TransitError.Or[DecryptResult]]) => DecryptBatchResponse(br))
Decoder.forProduct1("batch_results")((br: NonEmptyList[TransitError.Or[DecryptResult]]) => DecryptBatchResponse(br))
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.banno.vault.transit

import cats.data.{EitherT, NonEmptyList}
import cats.syntax.eq._
import cats.instances.option._
import cats.effect.Sync
Expand All @@ -24,14 +25,10 @@ import org.http4s.dsl.Http4sDsl
import org.http4s.circe._
import org.http4s.{DecodeFailure, EntityDecoder, HttpApp, Request, Response}
import org.http4s.util.CaseInsensitiveString
import cats.data.EitherT

final class MockTransitService[F[_]: Sync](
keyname: String,
token: String,
encryptCase: EncryptCase,
batchCases: List[EncryptCase]
) extends Http4sDsl[F] {
final class MockTransitService[F[_]: Sync]
(keyname: String, token: String, encryptCases: NonEmptyList[EncryptCase])
extends Http4sDsl[F] {

private implicit val encryptRequestEntityDecoder: EntityDecoder[F, EncryptRequest] = jsonOf
private implicit val decryptRequestEntityDecoder: EntityDecoder[F, DecryptRequest] = jsonOf
Expand Down Expand Up @@ -70,39 +67,40 @@ final class MockTransitService[F[_]: Sync](

private def encryptOne(req: Request[F]): EitherT[F, DecodeFailure, Response[F]] =
req.attemptAs[EncryptRequest].semiflatMap { case encReq =>
if (encryptCase.matches(encReq))
Ok( Json.obj("data" -> encryptResult(encryptCase.ciphertext)))
else
Gone()
encryptCases.find(_.matches(encReq)) match {
case Some(encryptCase) => Ok( Json.obj("data" -> encryptResult(encryptCase.ciphertext)))
case None => Gone()
}
}

private def decryptOne(req: Request[F]): EitherT[F, DecodeFailure, Response[F]] =
req.attemptAs[DecryptRequest].semiflatMap { case decreq =>
if (encryptCase.matches(decreq))
Ok( Json.obj("data" -> decryptResult(encryptCase.plaintext)))
else Gone()
req.attemptAs[DecryptRequest].semiflatMap { case decReq =>
encryptCases.find(_.matches(decReq)) match {
case Some(encryptCase) => Ok( Json.obj("data" -> decryptResult(encryptCase.plaintext)))
case None => Gone()
}
}

private def decryptBatch(req: Request[F]): EitherT[F, DecodeFailure, Response[F]] =
req.attemptAs[DecryptBatchRequest].semiflatMap { case DecryptBatchRequest(inputs) =>
val results: List[Json] = inputs.map { case decreq =>
batchCases.find(_.matches(decreq)) match {
val results: NonEmptyList[Json] = inputs.map { case decreq =>
encryptCases.find(_.matches(decreq)) match {
case None => error("Not known for this context or ciphertext")
case Some(bc) => decryptResult(bc.plaintext)
}
}
Ok(Json.obj("batch_results" -> Json.fromValues(results)))
Ok(Json.obj("batch_results" -> Json.fromValues(results.toList)))
}

private def encryptBatch(req: Request[F]): EitherT[F, DecodeFailure, Response[F]] =
req.attemptAs[EncryptBatchRequest].semiflatMap { case EncryptBatchRequest(inputs) =>
val results: List[Json] = inputs.map { case encReq =>
batchCases.find(_.matches(encReq)) match {
val results: NonEmptyList[Json] = inputs.map { case encReq =>
encryptCases.find(_.matches(encReq)) match {
case None => error("Not known for this context or plaintext")
case Some(bc) => encryptResult(bc.ciphertext)
}
}
Ok(Json.obj("batch_results" -> Json.fromValues(results)))
Ok(Json.obj("batch_results" -> Json.fromValues(results.toList)))
}
}

Expand Down
12 changes: 6 additions & 6 deletions core/src/test/scala/com/banno/vault/transit/ModelsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
package com.banno.vault.transit

import io.circe.Json
import cats.data.NonEmptyList
import cats.implicits._
import org.specs2.{Spec, ScalaCheck}
import org.specs2.specification.core.SpecStructure
import org.scalacheck.Prop
import io.circe.syntax._
import org.scalacheck.Gen

object TransitModelsSpec extends Spec with ScalaCheck {
import TransitGenerators._
Expand Down Expand Up @@ -96,20 +96,20 @@ object TransitModelsSpec extends Spec with ScalaCheck {
"plaintext" -> Json.fromString(pt.value),
"context" -> Json.fromString(ctx.value)
)
})
}.toList)
)
}

val encodeEncryptBatchResponseProp: Prop = Prop.forAll(Gen.listOf(cipherText)){ cts =>
val encodeEncryptBatchResponseProp: Prop = Prop.forAll(nelGen(cipherText)){ cts =>
val json = Json.obj("batch_response" -> Json.fromValues(
cts.map((ct: CipherText) => Json.obj("ciphertext" -> Json.fromString(ct.ciphertext)))
cts.map((ct: CipherText) => Json.obj("ciphertext" -> Json.fromString(ct.ciphertext))).toList
))
EncryptBatchResponse(cts.map((ct: CipherText) => Right(EncryptResult(ct)))).asJson === json
}

val decodeDecryptBatchResponseProp: Prop = Prop.forAll(Gen.listOf(base64)){ (plaintexts: List[Base64]) =>
val decodeDecryptBatchResponseProp: Prop = Prop.forAll(nelGen(base64)){ (plaintexts: NonEmptyList[Base64]) =>
val json = Json.obj( "batch_response" -> Json.fromValues(
plaintexts.map( pt => DecryptResult(PlainText(pt)).asJson)
plaintexts.map( pt => DecryptResult(PlainText(pt)).asJson).toList
))
val expected = DecryptBatchResponse(plaintexts.map( (pt: Base64) => Right(DecryptResult(PlainText(pt)))))
DecryptBatchResponse.decodeDecryptBatchResponse.decodeJson(json) === Right(expected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package com.banno.vault.transit

import cats.data.NonEmptyList
import com.banno.vault.VaultArbitraries
import scodec.bits.ByteVector
import org.scalacheck.Gen

object TransitGenerators extends VaultArbitraries {

def nelGen[A](base: Gen[A]): Gen[NonEmptyList[A]] =
Gen.nonEmptyListOf(base).map((xs: List[A]) => NonEmptyList.fromListUnsafe(xs))

// copied from scodec-bits repository.
def standardByteVectors(maxSize: Int): Gen[ByteVector] = for {
size <- Gen.choose(0, maxSize)
Expand All @@ -46,15 +50,15 @@ object TransitGenerators extends VaultArbitraries {
cipherText.map((p: CipherText) => EncryptResult(p))

val genEncryptBatchRequest: Gen[EncryptBatchRequest] =
Gen.listOf(genEncryptRequest).map(ps => EncryptBatchRequest(ps))
nelGen(genEncryptRequest).map(ps => EncryptBatchRequest(ps))

val genAllRightEncryptBatchResponse: Gen[EncryptBatchResponse] =
Gen.listOf(right[TransitError, EncryptResult](encryptResult))
.map( (rps: List[TransitError.Or[EncryptResult]]) => EncryptBatchResponse(rps))
nelGen(right[TransitError, EncryptResult](encryptResult))
.map( (rps: NonEmptyList[TransitError.Or[EncryptResult]]) => EncryptBatchResponse(rps))

val genEncryptBatchResponse: Gen[EncryptBatchResponse] =
Gen.listOf(errorOr(encryptResult))
.map((rps: List[TransitError.Or[EncryptResult]]) => EncryptBatchResponse(rps))
nelGen(errorOr(encryptResult))
.map((rps: NonEmptyList[TransitError.Or[EncryptResult]]) => EncryptBatchResponse(rps))

def some[A](genA: Gen[A]): Gen[Option[A]] = genA.map( (a:A) => Some(a) )
def right[A, B](genB: Gen[B]): Gen[Either[A, B]] = genB.map(b => Right(b))
Expand Down
17 changes: 10 additions & 7 deletions core/src/test/scala/com/banno/vault/transit/TransitSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.banno.vault.transit

import cats.data.NonEmptyList
import cats.effect.IO
import cats.implicits._
import cats.kernel.Eq
Expand All @@ -39,6 +40,8 @@ object TransitSpec extends Spec with ScalaCheck with TransitData {
| decryptBatch may work for all inputs $decryptBatchAllFineSpec
""".stripMargin

import TransitGenerators.nelGen

val encryptSpec: Prop = Prop.forAll(genTestCase){ testCase =>
val transit = new TransitClient[IO](testCase.singleMockClient, testUri, token, KeyName(keyName))
val plainText = PlainText(Order.toBase64(testCase.order))
Expand Down Expand Up @@ -72,30 +75,30 @@ object TransitSpec extends Spec with ScalaCheck with TransitData {
actual.attempt.unsafeRunSync.isLeft
}

val encryptBatchAllFineSpec: Prop = Prop.forAll(Gen.nonEmptyListOf(genTestCase)){ testCases =>
val encryptBatchAllFineSpec: Prop = Prop.forAll(nelGen(genTestCase)){ testCases =>
val encCases = testCases.map(_.encryptCase)
val mockClient = Client.fromHttpApp {
new MockTransitService[IO](keyName, "vaultToken", encCases.head, encCases).routes
new MockTransitService[IO](keyName, "vaultToken", encCases).routes
}
val transit = new TransitClient[IO](mockClient, testUri, token, KeyName(keyName))
val inputs = testCases.map( tc => (tc.plaintext, tc.context))
val actual = transit.encryptInContextBatch(inputs).attempt.unsafeRunSync()
actual.isRight &&
actual.forall(_.forall(_.isRight)) &&
actual.forall(_.zip(testCases).forall { case (res, inp) => res === Right(inp.encrypted) })
actual.forall(_.toList.zip(testCases.toList).forall { case (res, inp) => res === Right(inp.encrypted) })
}

val decryptBatchAllFineSpec: Prop = Prop.forAll(Gen.nonEmptyListOf(genTestCase)){ testCases =>
val decryptBatchAllFineSpec: Prop = Prop.forAll(nelGen(genTestCase)){ testCases =>
val encCases = testCases.map(_.encryptCase)
val mockClient = Client.fromHttpApp {
new MockTransitService[IO](keyName, "vaultToken", encCases.head, encCases).routes
new MockTransitService[IO](keyName, "vaultToken", encCases).routes
}
val transit = new TransitClient[IO](mockClient, testUri, token, KeyName(keyName))
val inputs = testCases.map( tc => (tc.encrypted, tc.context))
val actual = transit.decryptInContextBatch(inputs).attempt.unsafeRunSync()
actual.isRight &&
actual.forall(_.forall(_.isRight)) &&
actual.forall(_.zip(testCases).forall { case (res, inp) => res === Right(inp.plaintext) })
actual.forall(_.toList.zip(testCases.toList).forall { case (res, inp) => res === Right(inp.plaintext) })
}
}

Expand All @@ -108,7 +111,7 @@ trait TransitData {
val context = Context(Agent.toBase64(agent))
def encryptCase: EncryptCase = EncryptCase(plaintext, Some(context), encrypted)
def singleMockClient: Client[IO] = Client.fromHttpApp {
new MockTransitService[IO](keyName, "vaultToken", encryptCase, Nil).routes
new MockTransitService[IO](keyName, "vaultToken", NonEmptyList.of(encryptCase)).routes
}
}

Expand Down

0 comments on commit 7a419af

Please sign in to comment.