Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/hazmat/primitives/hpke.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ specifying auxiliary authenticated information.

HKDF-SHA384

.. attribute:: HKDF_SHA512

HKDF-SHA512

.. attribute:: SHAKE128

SHAKE-128

.. class:: AEAD

An enumeration of authenticated encryption algorithms.
Expand Down
1 change: 1 addition & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class KDF:
HKDF_SHA256: KDF
HKDF_SHA384: KDF
HKDF_SHA512: KDF
SHAKE128: KDF

class AEAD:
AES_128_GCM: AEAD
Expand Down
105 changes: 94 additions & 11 deletions src/rust/src/backend/hpke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use pyo3::types::{PyAnyMethods, PyBytesMethods};

use crate::backend::aead::{AesGcm, ChaCha20Poly1305};
use crate::backend::hashes::Hash;
use crate::backend::kdf::{hkdf_extract, HkdfExpand};
use crate::backend::x25519;
use crate::buf::CffiBuf;
Expand All @@ -14,6 +15,15 @@ use crate::{exceptions, types};
const HPKE_VERSION: &[u8] = b"HPKE-v1";
const HPKE_MODE_BASE: u8 = 0x00;

fn u16_length_prefix(length: usize, label: &str) -> CryptographyResult<[u8; 2]> {
let length = u16::try_from(length).map_err(|_| {
CryptographyError::from(pyo3::exceptions::PyValueError::new_err(format!(
"{label} is too large."
)))
})?;
Ok(length.to_be_bytes())
}

mod kem_params {
pub const X25519_ID: u16 = 0x0020;
pub const X25519_NSECRET: usize = 32;
Expand All @@ -24,6 +34,8 @@ mod kdf_params {
pub const HKDF_SHA256_ID: u16 = 0x0001;
pub const HKDF_SHA384_ID: u16 = 0x0002;
pub const HKDF_SHA512_ID: u16 = 0x0003;
pub const SHAKE128_ID: u16 = 0x0010;
pub const SHAKE128_NH: usize = 32;
}

mod aead_params {
Expand Down Expand Up @@ -70,6 +82,7 @@ pub(crate) enum KDF {
HKDF_SHA256,
HKDF_SHA384,
HKDF_SHA512,
SHAKE128,
}

impl KDF {
Expand All @@ -78,18 +91,17 @@ impl KDF {
KDF::HKDF_SHA256 => kdf_params::HKDF_SHA256_ID,
KDF::HKDF_SHA384 => kdf_params::HKDF_SHA384_ID,
KDF::HKDF_SHA512 => kdf_params::HKDF_SHA512_ID,
KDF::SHAKE128 => kdf_params::SHAKE128_ID,
}
}

fn hash_algorithm<'p>(
&self,
py: pyo3::Python<'p>,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
match self {
KDF::HKDF_SHA256 => Ok(types::SHA256.get(py)?.call0()?),
KDF::HKDF_SHA384 => Ok(types::SHA384.get(py)?.call0()?),
KDF::HKDF_SHA512 => Ok(types::SHA512.get(py)?.call0()?),
}
fn nh(&self) -> usize {
debug_assert!(self.is_one_stage());
kdf_params::SHAKE128_NH
}

fn is_one_stage(&self) -> bool {
matches!(self, KDF::SHAKE128)
}
}

Expand Down Expand Up @@ -169,6 +181,20 @@ impl Suite {
hkdf_expand.derive(py, CffiBuf::from_bytes(py, prk))
}

fn hpke_hkdf_hash_algorithm<'p>(
&self,
py: pyo3::Python<'p>,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
debug_assert!(!self.kdf.is_one_stage());
if self.kdf == KDF::HKDF_SHA256 {
Ok(types::SHA256.get(py)?.call0()?)
} else if self.kdf == KDF::HKDF_SHA384 {
Ok(types::SHA384.get(py)?.call0()?)
} else {
Ok(types::SHA512.get(py)?.call0()?)
}
}

fn kem_labeled_extract(
&self,
py: pyo3::Python<'_>,
Expand Down Expand Up @@ -285,7 +311,7 @@ impl Suite {
labeled_ikm.extend_from_slice(label);
labeled_ikm.extend_from_slice(ikm);

let algorithm = self.kdf.hash_algorithm(py)?;
let algorithm = self.hpke_hkdf_hash_algorithm(py)?;
let buf = CffiBuf::from_bytes(py, &labeled_ikm);
hkdf_extract(py, &algorithm.unbind(), salt, &buf)
}
Expand All @@ -305,10 +331,35 @@ impl Suite {
labeled_info.extend_from_slice(&self.hpke_suite_id);
labeled_info.extend_from_slice(label);
labeled_info.extend_from_slice(info);
let algorithm = self.kdf.hash_algorithm(py)?;
let algorithm = self.hpke_hkdf_hash_algorithm(py)?;
Suite::hkdf_expand(py, algorithm, prk, &labeled_info, length)
}

fn hpke_labeled_derive<'p>(
&self,
py: pyo3::Python<'p>,
ikm: &[u8],
label: &[u8],
context: &[u8],
length: usize,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let label_len = u16_length_prefix(label.len(), "label")?;
let mut labeled_ikm = Vec::with_capacity(
ikm.len() + HPKE_VERSION.len() + 10 + 2 + label.len() + 2 + context.len(),
);
labeled_ikm.extend_from_slice(ikm);
labeled_ikm.extend_from_slice(HPKE_VERSION);
labeled_ikm.extend_from_slice(&self.hpke_suite_id);
labeled_ikm.extend_from_slice(&label_len);
labeled_ikm.extend_from_slice(label);
labeled_ikm.extend_from_slice(&(length as u16).to_be_bytes());
labeled_ikm.extend_from_slice(context);
let algorithm = types::SHAKE128.get(py)?.call1((length,))?;
let mut hash = Hash::new(py, &algorithm, None)?;
hash.update_bytes(&labeled_ikm)?;
hash.finalize(py)
}

fn aead_encrypt<'p>(
&self,
py: pyo3::Python<'p>,
Expand Down Expand Up @@ -415,6 +466,38 @@ impl Suite {
pyo3::Bound<'p, pyo3::types::PyBytes>,
pyo3::Bound<'p, pyo3::types::PyBytes>,
)> {
if self.kdf.is_one_stage() {
let shared_secret_len = u16_length_prefix(shared_secret.len(), "shared_secret")?;
let info_len = u16_length_prefix(info.len(), "info")?;

let mut secrets = Vec::with_capacity(4 + shared_secret.len());
secrets.extend_from_slice(&0u16.to_be_bytes());
secrets.extend_from_slice(&shared_secret_len);
secrets.extend_from_slice(shared_secret);

let mut key_schedule_context = Vec::with_capacity(5 + info.len());
key_schedule_context.push(HPKE_MODE_BASE);
key_schedule_context.extend_from_slice(&0u16.to_be_bytes());
key_schedule_context.extend_from_slice(&info_len);
key_schedule_context.extend_from_slice(info);

let key_length = self.aead.key_length();
let nonce_length = self.aead.nonce_length();
let secret = self.hpke_labeled_derive(
py,
&secrets,
b"secret",
&key_schedule_context,
key_length + nonce_length + self.kdf.nh(),
)?;
let secret_bytes = secret.as_bytes();
let key = pyo3::types::PyBytes::new(py, &secret_bytes[..key_length]);
let base_nonce =
pyo3::types::PyBytes::new(py, &secret_bytes[key_length..key_length + nonce_length]);

return Ok((key, base_nonce));
}

let psk_id_hash = self.hpke_labeled_extract(py, None, b"psk_id_hash", b"")?;
let info_hash = self.hpke_labeled_extract(py, None, b"info_hash", info)?;
let mut key_schedule_context = vec![HPKE_MODE_BASE];
Expand Down
2 changes: 2 additions & 0 deletions src/rust/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ pub static SHA384: LazyPyImport =
LazyPyImport::new("cryptography.hazmat.primitives.hashes", &["SHA384"]);
pub static SHA512: LazyPyImport =
LazyPyImport::new("cryptography.hazmat.primitives.hashes", &["SHA512"]);
pub static SHAKE128: LazyPyImport =
LazyPyImport::new("cryptography.hazmat.primitives.hashes", &["SHAKE128"]);

pub static NO_DIGEST_INFO: LazyPyImport = LazyPyImport::new(
"cryptography.hazmat.primitives.asymmetric.utils",
Expand Down
58 changes: 58 additions & 0 deletions tests/hazmat/primitives/test_hpke.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives.hpke import (
AEAD,
Expand All @@ -30,12 +31,21 @@
)
)

SUPPORTED_SHAKE128_AEADS = [
AEAD.AES_128_GCM,
AEAD.AES_256_GCM,
AEAD.CHACHA20_POLY1305,
]


@pytest.mark.supported(
only_if=lambda backend: backend.x25519_supported(),
skip_message="Requires OpenSSL with X25519 support",
)
class TestHPKE:
def test_shake128_is_available(self):
assert isinstance(KDF.SHAKE128, KDF)

def test_invalid_kem_type(self):
with pytest.raises(TypeError):
Suite("not a kem", KDF.HKDF_SHA256, AEAD.AES_128_GCM) # type: ignore[arg-type]
Expand Down Expand Up @@ -272,3 +282,51 @@ def test_vector_decryption(self, subtests):
suite, ciphertext, sk_r, info=info, aad=aad
)
assert pt == pt_expected

@pytest.mark.supported(
only_if=lambda backend: backend.hash_supported(
hashes.SHAKE128(digest_size=32)
),
skip_message="Does not support SHAKE128",
)
@pytest.mark.parametrize("aead", SUPPORTED_SHAKE128_AEADS)
def test_roundtrip_shake128(self, aead):
suite = Suite(KEM.X25519, KDF.SHAKE128, aead)

sk_r = x25519.X25519PrivateKey.generate()
pk_r = sk_r.public_key()

ciphertext = suite.encrypt(b"Hello, HPKE!", pk_r, info=b"shake128")
plaintext = suite.decrypt(ciphertext, sk_r, info=b"shake128")

assert plaintext == b"Hello, HPKE!"

@pytest.mark.supported(
only_if=lambda backend: backend.hash_supported(
hashes.SHAKE128(digest_size=32)
),
skip_message="Does not support SHAKE128",
)
def test_info_mismatch_fails_shake128(self):
suite = Suite(KEM.X25519, KDF.SHAKE128, AEAD.AES_128_GCM)

sk_r = x25519.X25519PrivateKey.generate()
pk_r = sk_r.public_key()

ciphertext = suite.encrypt(b"Secret", pk_r, info=b"sender info")

with pytest.raises(InvalidTag):
suite.decrypt(ciphertext, sk_r, info=b"different info")

@pytest.mark.supported(
only_if=lambda backend: backend.hash_supported(
hashes.SHAKE128(digest_size=32)
),
skip_message="Does not support SHAKE128",
)
def test_info_too_large_fails_shake128(self):
suite = Suite(KEM.X25519, KDF.SHAKE128, AEAD.AES_128_GCM)
pk_r = x25519.X25519PrivateKey.generate().public_key()

with pytest.raises(ValueError, match="info is too large"):
suite.encrypt(b"test", pk_r, info=b"x" * 65536)