Skip to content

Commit

Permalink
refactor(symmetric)!: replace aad with tag
Browse files Browse the repository at this point in the history
  • Loading branch information
jhdcruz committed Jul 28, 2023
1 parent 4a0e7a6 commit 4f9acb4
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const val TAG_LENGTH: Int = 128
/**
* Authenticated Encryption with Associated Data.
*
* It incorporates `aad` that provides a cryptographic checksum that can be used to help
* It incorporates `tag` that provides a cryptographic checksum that can be used to help
* validate a decryption such as additional clear text, or associated data used for validation
*
* @param algorithm algorithm to use for KeyGenerator (e.g. AES, ChaCha20).
Expand All @@ -29,7 +29,7 @@ abstract class AEAD(
) : SymmetricEncryption(algorithm, mode) {

/**
* Encrypts the provided [data] along with [aad] (if provided) using [key].
* Encrypts the provided [data] along with [tag] (if provided) using [key].
*
* This is useful for **advanced use cases** if you want finer control.
*
Expand All @@ -43,17 +43,17 @@ abstract class AEAD(
@NotNull data: ByteArray,
@NotNull iv: ByteArray,
@NotNull key: ByteArray,
@NotNull aad: ByteArray = byteArrayOf(),
@NotNull tag: ByteArray = byteArrayOf(),
): Map<String, ByteArray> {
val keySpec = SecretKeySpec(key, algorithm)
val parameterSpec = GCMParameterSpec(TAG_LENGTH, iv)

return cipher.run {
init(Cipher.ENCRYPT_MODE, keySpec, parameterSpec)

// add aad if not empty
if (aad.isNotEmpty()) {
updateAAD(aad)
// add tag if not empty
if (tag.isNotEmpty()) {
updateAAD(tag)
}

doFinal(data)
Expand All @@ -66,7 +66,7 @@ abstract class AEAD(
}

/**
* Decrypts [encrypted] data with optional [aad] verification using [key] nad [iv].
* Decrypts [encrypted] data with optional [tag] verification using [key] nad [iv].
*
* @return Decrypted data
*/
Expand All @@ -75,7 +75,7 @@ abstract class AEAD(
@NotNull encrypted: ByteArray,
@NotNull iv: ByteArray,
@NotNull key: ByteArray,
@NotNull aad: ByteArray = byteArrayOf(),
@NotNull tag: ByteArray = byteArrayOf(),
): ByteArray {
return try {
val keySpec = SecretKeySpec(key, algorithm)
Expand All @@ -84,54 +84,54 @@ abstract class AEAD(
cipher.run {
init(Cipher.DECRYPT_MODE, keySpec, gcmIv)

// check aad if not empty
if (aad.isNotEmpty()) {
updateAAD(aad)
// check tag if not empty
if (tag.isNotEmpty()) {
updateAAD(tag)
}

doFinal(encrypted)
}
} catch (e: AEADBadTagException) {
throw KipherException(
"Invalid additional authenticated data (AAD), data might have been tampered.",
"Invalid additional authenticated data (tag), data might have been tampered.",
e,
)
}
}

/**
* Encrypts the provided [data] along with optional [aad] and [key].
* Encrypts the provided [data] along with optional [tag] and [key].
*
* This method already generates a new key for each encryption.
* [generateKey] is optional.
*
* If you want to use custom keys, and leave [aad] empty,
* If you want to use custom keys, and leave [tag] empty,
* pass an empty [Byte] instead of `null`.
*
* @return Concatenated encrypted data in `[iv, data]` format with `key` and `aad`.
* @return Concatenated encrypted data in `[iv, data]` format with `key` and `tag`.
*/
@JvmOverloads
fun encrypt(
@NotNull data: ByteArray,
@NotNull aad: ByteArray = byteArrayOf(),
@NotNull tag: ByteArray = byteArrayOf(),
@NotNull key: ByteArray = generateKey(),
): Map<String, ByteArray> {
val encrypted = encryptBare(
data = data,
iv = generateIv(),
key = key,
aad = aad,
tag = tag,
).concat()

return mapOf(
"data" to encrypted,
"key" to key,
"aad" to aad,
"tag" to tag,
)
}

/**
* Decrypts [encrypted] data using [key] and [aad] if provided.
* Decrypts [encrypted] data using [key] and [tag] if provided.
*
* This method assumes that the [encrypted] data is in `[iv, data]` format,
* presumably encrypted using [encrypt].
Expand All @@ -140,14 +140,14 @@ abstract class AEAD(
fun decrypt(
@NotNull encrypted: ByteArray,
@NotNull key: ByteArray,
@NotNull aad: ByteArray = byteArrayOf(),
@NotNull tag: ByteArray = byteArrayOf(),
): ByteArray {
encrypted.extract().let { data ->
return decryptBare(
encrypted = data.getValue("data"),
iv = data.getValue("iv"),
key = key,
aad = aad,
tag = tag,
)
}
}
Expand All @@ -157,19 +157,19 @@ abstract class AEAD(
*
* This method assumes that [encrypted] is a [Map] that contains
* concatenated encrypted data in `[iv, data]` format
* with `key` and `aad`, presumably encrypted using [encrypt].
* with `key` and `tag`, presumably encrypted using [encrypt].
*/
fun decrypt(@NotNull encrypted: Map<String, ByteArray>): ByteArray {
val key = encrypted.getValue("key")
val aad = encrypted.getValue("aad")
val tag = encrypted.getValue("tag")
val concatData = encrypted.getValue("data")

concatData.extract().let { data ->
return decryptBare(
encrypted = data.getValue("data"),
iv = data.getValue("iv"),
key = key,
aad = aad,
tag = tag,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

package io.github.jhdcruz.kipher.symmetric

import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.aad
import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.invalidKey
import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.message
import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.tag
import io.github.jhdcruz.kipher.symmetric.aes.AesCBC
import io.github.jhdcruz.kipher.symmetric.aes.AesGCM
import org.junit.jupiter.api.RepeatedTest
Expand Down Expand Up @@ -39,7 +39,7 @@ internal class BaseSymmetricTest {
val aesGcm = AesGCM()

assertThrows<InvalidAlgorithmParameterException> {
aesGcm.encrypt(message, invalidKey, aad)
aesGcm.encrypt(message, invalidKey, tag)
}
}

Expand All @@ -56,10 +56,10 @@ internal class BaseSymmetricTest {
@Test
fun `test authenticated decryption using invalid secret key`() {
val aesGcm = AesGCM()
val encrypted = aesGcm.encrypt(message, aad)
val encrypted = aesGcm.encrypt(message, tag)

assertThrows<InvalidAlgorithmParameterException> {
aesGcm.decrypt(encrypted["data"]!!, invalidKey, encrypted["aad"]!!)
aesGcm.decrypt(encrypted["data"]!!, invalidKey, encrypted["tag"]!!)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
package io.github.jhdcruz.kipher.symmetric

import io.github.jhdcruz.kipher.core.Format.toHexString
import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.aad
import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.message
import io.github.jhdcruz.kipher.symmetric.SymmetricTestParams.tag
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import java.util.stream.Stream
Expand Down Expand Up @@ -65,10 +65,10 @@ internal class SymmetricEncryptionTest {

@ParameterizedTest
@MethodSource("io.github.jhdcruz.kipher.symmetric.SymmetricTestParams#getAeadClasses")
fun `AEAD encryption test with AAD`(encryptionClass: Class<out AEAD>) {
fun `AEAD encryption test with tag`(encryptionClass: Class<out AEAD>) {
val encryption = encryptionClass.getDeclaredConstructor().newInstance()

val encrypted = encryption.encrypt(message, aad)
val encrypted = encryption.encrypt(message, tag)
val decrypted = encryption.decrypt(encrypted)

println("${encryption.mode} = ${encrypted["data"]!!.size}")
Expand All @@ -84,9 +84,9 @@ internal class SymmetricEncryptionTest {
fun `AEAD encryption test with parameters`(encryptionClass: Class<out AEAD>) {
val encryption = encryptionClass.getDeclaredConstructor().newInstance()

val encrypted = encryption.encrypt(message, aad)
val encrypted = encryption.encrypt(message, tag)
val decrypted =
encryption.decrypt(encrypted["data"]!!, encrypted["key"]!!, encrypted["aad"]!!)
encryption.decrypt(encrypted["data"]!!, encrypted["key"]!!, encrypted["tag"]!!)

assertEquals(message.decodeToString(), decrypted.decodeToString())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import java.util.stream.Stream
*/
internal object SymmetricTestParams {
val message = "message".encodeToByteArray()
val aad = "metadata".encodeToByteArray()
val tag = "metadata".encodeToByteArray()
val invalidKey = "invalid-key".encodeToByteArray()

/**
Expand Down

0 comments on commit 4f9acb4

Please sign in to comment.