Skip to content

Commit

Permalink
RFC
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Feb 7, 2025
1 parent 6ed1729 commit d67250d
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 63 deletions.
2 changes: 1 addition & 1 deletion manul/src/dev/session_parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl<D: digest::Digest> signature::DigestVerifier<D, TestSignature> for TestVeri
#[derive(Debug, Clone, Copy, Default)]
pub struct TestHasher {
cursor: usize,
buffer: [u8; 32],
buffer: digest::Output<Self>,
}

impl digest::HashMarker for TestHasher {}
Expand Down
1 change: 1 addition & 0 deletions manul/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ pub use round::{
pub use serialization::{Deserializer, Serializer};

pub(crate) use errors::ReceiveErrorType;
pub(crate) use message::ProtocolMessagePartHashable;
pub(crate) use object_safe::{BoxedRng, ObjectSafeRound};
49 changes: 49 additions & 0 deletions manul/src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::string::{String, ToString};

use digest::Digest;
use serde::{Deserialize, Serialize};

use super::{
Expand Down Expand Up @@ -121,6 +122,42 @@ pub trait ProtocolMessagePart: ProtocolMessageWrapper {
}
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub(crate) enum PartKind {
EchoBroadcast,
NormalBroadcast,
DirectMessage,
}

pub(crate) trait HasPartKind {
const KIND: PartKind;
}

// We don't want to expose this functionality to the user, so it is separate from `ProtocolMessagePart` trait.
pub(crate) trait ProtocolMessagePartHashable: ProtocolMessagePart + HasPartKind {
fn hash<D: Digest>(&self) -> digest::Output<D> {
let mut digest = D::new_with_prefix(b"ProtocolMessagePart");
match Self::KIND {
PartKind::EchoBroadcast => digest.update([0u8]),
PartKind::NormalBroadcast => digest.update([1u8]),
PartKind::DirectMessage => digest.update([2u8]),
}
match self.maybe_message().as_ref() {
None => digest.update([0u8]),
Some(payload) => {
let payload_len =
u64::try_from(payload.as_ref().len()).expect("payload length does not exceed 18 exabytes");
digest.update([1u8]);
digest.update(payload_len.to_be_bytes());
digest.update(payload);
}
};
digest.finalize()
}
}

impl<T: ProtocolMessagePart + HasPartKind> ProtocolMessagePartHashable for T {}

/// A serialized direct message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DirectMessage(Option<MessagePayload>);
Expand All @@ -135,6 +172,10 @@ impl ProtocolMessageWrapper for DirectMessage {
}
}

impl HasPartKind for DirectMessage {
const KIND: PartKind = PartKind::DirectMessage;
}

impl ProtocolMessagePart for DirectMessage {
type Error = DirectMessageError;
}
Expand All @@ -153,6 +194,10 @@ impl ProtocolMessageWrapper for EchoBroadcast {
}
}

impl HasPartKind for EchoBroadcast {
const KIND: PartKind = PartKind::EchoBroadcast;
}

impl ProtocolMessagePart for EchoBroadcast {
type Error = EchoBroadcastError;
}
Expand All @@ -171,6 +216,10 @@ impl ProtocolMessageWrapper for NormalBroadcast {
}
}

impl HasPartKind for NormalBroadcast {
const KIND: PartKind = PartKind::NormalBroadcast;
}

impl ProtocolMessagePart for NormalBroadcast {
type Error = NormalBroadcastError;
}
Expand Down
6 changes: 3 additions & 3 deletions manul/src/session/evidence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
protocol::{
Deserializer, DirectMessage, DirectMessageError, EchoBroadcast, EchoBroadcastError, MessageValidationError,
NormalBroadcast, NormalBroadcastError, Protocol, ProtocolError, ProtocolMessage, ProtocolMessagePart,
ProtocolValidationError, RoundId,
ProtocolMessagePartHashable, ProtocolValidationError, RoundId,
},
utils::SerializableMap,
};
Expand Down Expand Up @@ -438,7 +438,7 @@ fn verify_message_parts<SP, T>(
) -> Result<BTreeMap<RoundId, T>, EvidenceError>
where
SP: SessionParameters,
T: Clone + ProtocolMessagePart,
T: Clone + ProtocolMessagePartHashable,
{
let mut verified_parts = BTreeMap::new();
for (round_id, message_part) in message_parts.iter() {
Expand All @@ -462,7 +462,7 @@ fn verify_message_part<SP, T>(
) -> Result<T, EvidenceError>
where
SP: SessionParameters,
T: Clone + ProtocolMessagePart,
T: Clone + ProtocolMessagePartHashable,
{
let verified_part = if let Some(message_part) = message_part {
let metadata = message_part.metadata();
Expand Down
106 changes: 47 additions & 59 deletions manul/src/session/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ use super::{
wire_format::WireFormat,
LocalError,
};
use crate::protocol::{
DeserializationError, DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart, RoundId,
};
use crate::protocol::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePartHashable, RoundId};

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct SerializedSignature(#[serde(with = "SliceLike::<Hex>")] Box<[u8]>);
struct SerializedSignature(#[serde(with = "SliceLike::<Hex>")] Box<[u8]>);

impl SerializedSignature {
pub fn new<SP>(signature: SP::Signature) -> Result<Self, LocalError>
Expand All @@ -26,11 +24,11 @@ impl SerializedSignature {
SP::WireFormat::serialize(signature).map(Self)
}

pub fn deserialize<SP>(&self) -> Result<SP::Signature, DeserializationError>
pub fn deserialize<SP>(&self) -> Result<SP::Signature, MessageVerificationError>
where
SP: SessionParameters,
{
SP::WireFormat::deserialize::<SP::Signature>(&self.0)
SP::WireFormat::deserialize::<SP::Signature>(&self.0).map_err(|_| MessageVerificationError::InvalidSignature)
}
}

Expand All @@ -43,6 +41,18 @@ pub(crate) enum MessageVerificationError {
SignatureMismatch,
}

impl From<LocalError> for MessageVerificationError {
fn from(source: LocalError) -> Self {
Self::Local(source)
}
}

impl From<signature::Error> for MessageVerificationError {
fn from(_source: signature::Error) -> Self {
Self::SignatureMismatch
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct SignedMessagePart<M> {
signature: SerializedSignature,
Expand Down Expand Up @@ -78,38 +88,32 @@ pub struct MessageWithMetadata<M> {
message: M,
}

fn payload_hash<SP: SessionParameters>(message: &impl ProtocolMessagePart) -> digest::Output<SP::Digest> {
let digest = SP::Digest::new_with_prefix(b"PayloadHash");
let digest = match message.maybe_message().as_ref() {
None => digest.chain_update([0u8]),
Some(payload) => digest.chain_update([1u8]).chain_update(payload),
};
digest.finalize()
}

fn message_digest<SP: SessionParameters>(
metadata: &MessageMetadata,
payload_hash: &[u8],
message_part_hash: &[u8],
) -> Result<SP::Digest, LocalError> {
let message_part_hash_len =
u64::try_from(message_part_hash.as_ref().len()).expect("message part hash length does not exceed 18 exabytes");
Ok(SP::Digest::new_with_prefix(b"SignedMessagePartDigest")
.chain_update(SP::WireFormat::serialize(metadata)?)
.chain_update(payload_hash))
.chain_update(message_part_hash_len.to_be_bytes())
.chain_update(message_part_hash))
}

impl<M: ProtocolMessagePart> MessageWithMetadata<M> {
impl<M: ProtocolMessagePartHashable> MessageWithMetadata<M> {
fn digest<SP>(&self) -> Result<SP::Digest, LocalError>
where
SP: SessionParameters,
{
let payload_hash = payload_hash::<SP>(&self.message);
let digest = message_digest::<SP>(&self.metadata, &payload_hash)?;
let message_part_hash = self.message.hash::<SP::Digest>();
let digest = message_digest::<SP>(&self.metadata, &message_part_hash)?;
Ok(digest)
}
}

impl<M> SignedMessagePart<M>
where
M: ProtocolMessagePart,
M: ProtocolMessagePartHashable,
{
pub fn new<SP>(
rng: &mut impl CryptoRngCore,
Expand Down Expand Up @@ -137,11 +141,11 @@ where
where
SP: SessionParameters,
{
let payload_hash = payload_hash::<SP>(&self.message_with_metadata.message);
let message_part_hash = self.message_with_metadata.message.hash::<SP::Digest>();
SignedMessageHash {
signature: self.signature.clone(),
metadata: self.message_with_metadata.metadata.clone(),
payload_hash: payload_hash.as_ref().into(),
message_part_hash: message_part_hash.as_ref().into(),
}
}

Expand All @@ -157,22 +161,13 @@ where
where
SP: SessionParameters,
{
let digest = self
.message_with_metadata
.digest::<SP>()
.map_err(MessageVerificationError::Local)?;
let signature = self
.signature
.deserialize::<SP>()
.map_err(|_| MessageVerificationError::InvalidSignature)?;
if verifier.verify_digest(digest, &signature).is_ok() {
Ok(VerifiedMessagePart {
signature: self.signature,
message_with_metadata: self.message_with_metadata,
})
} else {
Err(MessageVerificationError::SignatureMismatch)
}
let digest = self.message_with_metadata.digest::<SP>()?;
let signature = self.signature.deserialize::<SP>()?;
verifier.verify_digest(digest, &signature)?;
Ok(VerifiedMessagePart {
signature: self.signature,
message_with_metadata: self.message_with_metadata,
})
}
}

Expand Down Expand Up @@ -354,7 +349,7 @@ pub(crate) struct SignedMessageHash {
signature: SerializedSignature,
metadata: MessageMetadata,
#[serde(with = "SliceLike::<Hex>")]
payload_hash: Box<[u8]>,
message_part_hash: Box<[u8]>,
}

impl SignedMessageHash {
Expand All @@ -366,29 +361,22 @@ impl SignedMessageHash {
where
SP: SessionParameters,
{
let digest =
message_digest::<SP>(&self.metadata, &self.payload_hash).map_err(MessageVerificationError::Local)?;
let signature = self
.signature
.deserialize::<SP>()
.map_err(|_| MessageVerificationError::InvalidSignature)?;
if verifier.verify_digest(digest, &signature).is_ok() {
Ok(VerifiedMessageHash {
signature: self.signature,
metadata: self.metadata,
payload_hash: self.payload_hash,
})
} else {
Err(MessageVerificationError::SignatureMismatch)
}
let digest = message_digest::<SP>(&self.metadata, &self.message_part_hash)?;
let signature = self.signature.deserialize::<SP>()?;
verifier.verify_digest(digest, &signature)?;
Ok(VerifiedMessageHash {
signature: self.signature,
metadata: self.metadata,
message_part_hash: self.message_part_hash,
})
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VerifiedMessageHash {
signature: SerializedSignature,
metadata: MessageMetadata,
payload_hash: Box<[u8]>,
message_part_hash: Box<[u8]>,
}

impl VerifiedMessageHash {
Expand All @@ -399,9 +387,9 @@ impl VerifiedMessageHash {
pub(crate) fn is_hash_of<SP, M>(&self, message: &SignedMessagePart<M>) -> bool
where
SP: SessionParameters,
M: ProtocolMessagePart,
M: ProtocolMessagePartHashable,
{
let payload_hash = payload_hash::<SP>(&message.message_with_metadata.message);
payload_hash.as_ref() == self.payload_hash.as_ref()
let message_part_hash = message.message_with_metadata.message.hash::<SP::Digest>();
message_part_hash.as_ref() == self.message_part_hash.as_ref()
}
}

0 comments on commit d67250d

Please sign in to comment.