From 1ac6779d9ee540babb67e82ca256bfaabb99ccd5 Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Thu, 26 Oct 2023 03:57:49 -0700 Subject: [PATCH] fix: update enum-try-as-inner (#80) --- Cargo.toml | 1 + ot/mpz-ot-core/Cargo.toml | 2 +- ot/mpz-ot-core/src/chou_orlandi/msgs.rs | 7 ++++ ot/mpz-ot-core/src/kos/msgs.rs | 7 ++++ ot/mpz-ot/Cargo.toml | 2 +- ot/mpz-ot/src/actor/kos/error.rs | 26 ++++++------ ot/mpz-ot/src/actor/kos/mod.rs | 28 +------------ ot/mpz-ot/src/actor/kos/msgs.rs | 7 ++++ ot/mpz-ot/src/actor/kos/receiver.rs | 6 +-- ot/mpz-ot/src/actor/kos/sender.rs | 4 +- ot/mpz-ot/src/chou_orlandi/error.rs | 30 ++++++++++---- ot/mpz-ot/src/chou_orlandi/receiver.rs | 33 +++++---------- ot/mpz-ot/src/chou_orlandi/sender.rs | 35 ++++++---------- ot/mpz-ot/src/kos/error.rs | 30 ++++++++++---- ot/mpz-ot/src/kos/mod.rs | 12 ++---- ot/mpz-ot/src/kos/receiver.rs | 42 +++++++++---------- ot/mpz-ot/src/kos/sender.rs | 55 ++++++++++++------------- 17 files changed, 160 insertions(+), 167 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index acb05109..e4599002 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,3 +92,4 @@ once_cell = "1" # DO NOT BUMP, SEE https://github.com/privacy-scaling-explorations/mpz/issues/61 generic-array = "0.14" itybity = "0.2" +enum-try-as-inner = "0.1.0" diff --git a/ot/mpz-ot-core/Cargo.toml b/ot/mpz-ot-core/Cargo.toml index 8e7132c7..7fb85d15 100644 --- a/ot/mpz-ot-core/Cargo.toml +++ b/ot/mpz-ot-core/Cargo.toml @@ -32,8 +32,8 @@ derive_builder.workspace = true itybity.workspace = true opaque-debug.workspace = true cfg-if.workspace = true -enum-try-as-inner = { tag = "0.1.0", git = "https://github.com/sinui0/enum-try-as-inner" } bytemuck = { workspace = true, features = ["derive"] } +enum-try-as-inner.workspace = true [dev-dependencies] rstest.workspace = true diff --git a/ot/mpz-ot-core/src/chou_orlandi/msgs.rs b/ot/mpz-ot-core/src/chou_orlandi/msgs.rs index 43d8c4a1..0d7c0494 100644 --- a/ot/mpz-ot-core/src/chou_orlandi/msgs.rs +++ b/ot/mpz-ot-core/src/chou_orlandi/msgs.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; /// A CO15 protocol message. #[derive(Debug, Clone, EnumTryAsInner, Serialize, Deserialize)] +#[derive_err(Debug)] #[allow(missing_docs)] pub enum Message { SenderSetup(SenderSetup), @@ -18,6 +19,12 @@ pub enum Message { CointossReceiverPayload(cointoss::msgs::ReceiverPayload), } +impl From for std::io::Error { + fn from(err: MessageError) -> Self { + std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()) + } +} + /// Sender setup message. #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] pub struct SenderSetup { diff --git a/ot/mpz-ot-core/src/kos/msgs.rs b/ot/mpz-ot-core/src/kos/msgs.rs index 72eb442f..afdf2b63 100644 --- a/ot/mpz-ot-core/src/kos/msgs.rs +++ b/ot/mpz-ot-core/src/kos/msgs.rs @@ -14,6 +14,7 @@ use crate::msgs::Derandomize; /// A KOS15 protocol message. #[derive(Debug, Clone, EnumTryAsInner, Serialize, Deserialize)] +#[derive_err(Debug)] #[allow(missing_docs)] pub enum Message { BaseMsg(BaseMsg), @@ -26,6 +27,12 @@ pub enum Message { CointossSenderPayload(CointossSenderPayload), } +impl From> for std::io::Error { + fn from(err: MessageError) -> Self { + std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()) + } +} + /// Extension message sent by the receiver. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Extend { diff --git a/ot/mpz-ot/Cargo.toml b/ot/mpz-ot/Cargo.toml index 8585018b..fb7d95c0 100644 --- a/ot/mpz-ot/Cargo.toml +++ b/ot/mpz-ot/Cargo.toml @@ -29,7 +29,7 @@ p256 = { workspace = true, optional = true } thiserror.workspace = true rayon = { workspace = true } itybity.workspace = true -enum-try-as-inner = { tag = "0.1.0", git = "https://github.com/sinui0/enum-try-as-inner" } +enum-try-as-inner.workspace = true opaque-debug.workspace = true serde = { workspace = true, optional = true } diff --git a/ot/mpz-ot/src/actor/kos/error.rs b/ot/mpz-ot/src/actor/kos/error.rs index e4d77dbc..6b17a438 100644 --- a/ot/mpz-ot/src/actor/kos/error.rs +++ b/ot/mpz-ot/src/actor/kos/error.rs @@ -1,5 +1,5 @@ use crate::{ - actor::kos::msgs::Message, + actor::kos::msgs::MessageError, kos::{ReceiverError, SenderError}, }; @@ -33,9 +33,9 @@ impl From for SenderActorError { } } -impl From> for SenderActorError { - fn from(value: enum_try_as_inner::Error) -> Self { - SenderError::StateError(value.to_string()).into() +impl From for SenderActorError { + fn from(err: crate::kos::SenderStateError) -> Self { + SenderError::from(err).into() } } @@ -51,11 +51,11 @@ impl From> for SenderActorError { } } -impl From>> for SenderActorError { - fn from(value: enum_try_as_inner::Error>) -> Self { +impl From> for SenderActorError { + fn from(err: MessageError) -> Self { SenderActorError::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, - value.to_string(), + err.to_string(), )) } } @@ -104,9 +104,9 @@ impl From for ReceiverActorError { } } -impl From> for ReceiverActorError { - fn from(value: enum_try_as_inner::Error) -> Self { - ReceiverError::StateError(value.to_string()).into() +impl From for ReceiverActorError { + fn from(err: crate::kos::ReceiverStateError) -> Self { + ReceiverError::from(err).into() } } @@ -122,11 +122,11 @@ impl From> for ReceiverActorError { } } -impl From>> for ReceiverActorError { - fn from(value: enum_try_as_inner::Error>) -> Self { +impl From> for ReceiverActorError { + fn from(err: MessageError) -> Self { ReceiverActorError::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, - value.to_string(), + err.to_string(), )) } } diff --git a/ot/mpz-ot/src/actor/kos/mod.rs b/ot/mpz-ot/src/actor/kos/mod.rs index d5eacb2d..988c300a 100644 --- a/ot/mpz-ot/src/actor/kos/mod.rs +++ b/ot/mpz-ot/src/actor/kos/mod.rs @@ -6,7 +6,7 @@ mod sender; use futures::{SinkExt, StreamExt}; use utils_aio::{sink::IoSink, stream::IoStream}; -use crate::kos::{msgs::Message as KosMessage, ReceiverError, SenderError}; +use crate::kos::msgs::Message as KosMessage; pub use error::{ReceiverActorError, SenderActorError}; pub use receiver::{ReceiverActor, SharedReceiver}; @@ -26,35 +26,11 @@ pub(crate) fn into_kos_stream<'a, St: IoStream> + Send + Unpin, stream: &'a mut St, ) -> impl IoStream> + Send + Unpin + 'a { StreamExt::map(stream, |msg| match msg { - Ok(msg) => match msg.into_protocol() { - Ok(msg) => Ok(msg), - Err(err) => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - err.to_string(), - )), - }, + Ok(msg) => msg.try_into_protocol().map_err(From::from), Err(err) => Err(err), }) } -impl From>> for SenderError { - fn from(value: enum_try_as_inner::Error>) -> Self { - SenderError::from(std::io::Error::new( - std::io::ErrorKind::InvalidData, - value.to_string(), - )) - } -} - -impl From>> for ReceiverError { - fn from(value: enum_try_as_inner::Error>) -> Self { - ReceiverError::from(std::io::Error::new( - std::io::ErrorKind::InvalidData, - value.to_string(), - )) - } -} - #[cfg(test)] mod tests { use crate::{ diff --git a/ot/mpz-ot/src/actor/kos/msgs.rs b/ot/mpz-ot/src/actor/kos/msgs.rs index 9ef85d90..b87c52ad 100644 --- a/ot/mpz-ot/src/actor/kos/msgs.rs +++ b/ot/mpz-ot/src/actor/kos/msgs.rs @@ -10,12 +10,19 @@ use mpz_ot_core::{ /// KOS actor message #[derive(Debug, Clone, EnumTryAsInner, Serialize, Deserialize)] +#[derive_err(Debug)] #[allow(missing_docs)] pub enum Message { ActorMessage(ActorMessage), Protocol(KosMessage), } +impl From> for std::io::Error { + fn from(err: MessageError) -> Self { + std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()) + } +} + impl From for Message { fn from(value: ActorMessage) -> Self { Message::ActorMessage(value) diff --git a/ot/mpz-ot/src/actor/kos/receiver.rs b/ot/mpz-ot/src/actor/kos/receiver.rs index 58b22c9e..f735cde5 100644 --- a/ot/mpz-ot/src/actor/kos/receiver.rs +++ b/ot/mpz-ot/src/actor/kos/receiver.rs @@ -153,7 +153,7 @@ where let mut keys = self .receiver .state_mut() - .as_extension_mut()? + .try_as_extension_mut()? .keys(choices.len())?; let derandomize = keys.derandomize(choices)?; @@ -205,7 +205,7 @@ where _ = caller_response.send( self.receiver .state_mut() - .as_verify_mut() + .try_as_verify_mut() .map_err(ReceiverError::from) .and_then(|receiver| { receiver.remove_record(*id).map_err(ReceiverError::from) @@ -262,7 +262,7 @@ where /// Handles a message from the KOS sender actor. async fn handle_msg(&mut self, msg: Message) -> Result<(), ReceiverActorError> { - let msg = msg.into_actor_message()?; + let msg = msg.try_into_actor_message()?; match msg { ActorMessage::TransferPayload(TransferPayload { id, payload }) => { diff --git a/ot/mpz-ot/src/actor/kos/sender.rs b/ot/mpz-ot/src/actor/kos/sender.rs index 19d2a06a..caeb8718 100644 --- a/ot/mpz-ot/src/actor/kos/sender.rs +++ b/ot/mpz-ot/src/actor/kos/sender.rs @@ -132,7 +132,7 @@ where futures::select! { // Processes a message received from the Receiver. msg = self.stream.select_next_some() => { - self.handle_msg(msg?.into_actor_message()?)?; + self.handle_msg(msg?.try_into_actor_message()?)?; } // Processes a command from a controller. cmd = self.commands.select_next_some() => { @@ -185,7 +185,7 @@ where let keys = self .sender .state_mut() - .as_extension_mut() + .try_as_extension_mut() .map_err(SenderError::from) .and_then(|sender| { sender diff --git a/ot/mpz-ot/src/chou_orlandi/error.rs b/ot/mpz-ot/src/chou_orlandi/error.rs index 5ba6b135..7da75ea4 100644 --- a/ot/mpz-ot/src/chou_orlandi/error.rs +++ b/ot/mpz-ot/src/chou_orlandi/error.rs @@ -1,4 +1,4 @@ -use mpz_ot_core::chou_orlandi::msgs::Message; +use mpz_ot_core::chou_orlandi::msgs::MessageError; use crate::OTError; @@ -10,7 +10,7 @@ pub enum SenderError { IOError(#[from] std::io::Error), #[error(transparent)] CoreError(#[from] mpz_ot_core::chou_orlandi::SenderError), - #[error("invalid state: expected {0}")] + #[error("{0}")] StateError(String), #[error(transparent)] CointossError(#[from] mpz_core::cointoss::CointossError), @@ -27,11 +27,17 @@ impl From for OTError { } } -impl From> for SenderError { - fn from(value: enum_try_as_inner::Error) -> Self { +impl From for SenderError { + fn from(err: crate::chou_orlandi::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +impl From for SenderError { + fn from(err: MessageError) -> Self { SenderError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, - value.to_string(), + err.to_string(), )) } } @@ -44,7 +50,7 @@ pub enum ReceiverError { IOError(#[from] std::io::Error), #[error(transparent)] CoreError(#[from] mpz_ot_core::chou_orlandi::ReceiverError), - #[error("invalid state: expected {0}")] + #[error("{0}")] StateError(String), #[error(transparent)] CointossError(#[from] mpz_core::cointoss::CointossError), @@ -61,11 +67,17 @@ impl From for OTError { } } -impl From> for ReceiverError { - fn from(value: enum_try_as_inner::Error) -> Self { +impl From for ReceiverError { + fn from(err: crate::chou_orlandi::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} + +impl From for ReceiverError { + fn from(err: MessageError) -> Self { ReceiverError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, - value.to_string(), + err.to_string(), )) } } diff --git a/ot/mpz-ot/src/chou_orlandi/receiver.rs b/ot/mpz-ot/src/chou_orlandi/receiver.rs index 2ac180ab..227bed3b 100644 --- a/ot/mpz-ot/src/chou_orlandi/receiver.rs +++ b/ot/mpz-ot/src/chou_orlandi/receiver.rs @@ -20,7 +20,8 @@ use crate::{CommittedOTReceiver, OTError, OTReceiver, OTSetup}; use super::ReceiverError; #[derive(Debug, EnumTryAsInner)] -enum State { +#[derive_err(Debug)] +pub(crate) enum State { Initialized { config: ReceiverConfig, seed: Option<[u8; 32]>, @@ -30,12 +31,6 @@ enum State { Error, } -impl From> for ReceiverError { - fn from(value: enum_try_as_inner::Error) -> Self { - ReceiverError::StateError(value.to_string()) - } -} - /// Chou-Orlandi receiver. #[derive(Debug)] pub struct Receiver { @@ -84,10 +79,8 @@ impl OTSetup for Receiver { return Ok(()); } - let (config, seed) = self - .state - .replace(State::Error) - .into_initialized() + let (config, seed) = std::mem::replace(&mut self.state, State::Error) + .try_into_initialized() .map_err(ReceiverError::from)?; // If the receiver is committed, we generate the seed using a cointoss. @@ -110,7 +103,7 @@ impl OTSetup for Receiver { let sender_setup = stream .expect_next() .await? - .into_sender_setup() + .try_into_sender_setup() .map_err(ReceiverError::from)?; let receiver = Backend::spawn(move || receiver.setup(sender_setup)).await; @@ -138,7 +131,7 @@ async fn execute_cointoss< let payload = stream .expect_next() .await? - .into_cointoss_receiver_payload()?; + .try_into_cointoss_receiver_payload()?; let (seeds, payload) = sender.finalize(payload)?; @@ -167,10 +160,8 @@ where stream: &mut St, choices: &[T], ) -> Result, OTError> { - let mut receiver = self - .state - .replace(State::Error) - .into_setup() + let mut receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_setup() .map_err(ReceiverError::from)?; let choices = choices.to_vec(); @@ -186,7 +177,7 @@ where let sender_payload = stream .expect_next() .await? - .into_sender_payload() + .try_into_sender_payload() .map_err(ReceiverError::from)?; let (receiver, data) = Backend::spawn(move || { @@ -213,10 +204,8 @@ impl CommittedOTReceiver for Receiver { sink: &mut Si, _stream: &mut St, ) -> Result<(), OTError> { - let receiver = self - .state - .replace(State::Error) - .into_setup() + let receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_setup() .map_err(ReceiverError::from)?; let Some(cointoss_payload) = self.cointoss_payload.take() else { diff --git a/ot/mpz-ot/src/chou_orlandi/sender.rs b/ot/mpz-ot/src/chou_orlandi/sender.rs index 8e90ac80..e758e4d8 100644 --- a/ot/mpz-ot/src/chou_orlandi/sender.rs +++ b/ot/mpz-ot/src/chou_orlandi/sender.rs @@ -16,19 +16,14 @@ use utils_aio::{ use enum_try_as_inner::EnumTryAsInner; #[derive(Debug, EnumTryAsInner)] -enum State { +#[derive_err(Debug)] +pub(crate) enum State { Initialized(SenderCore), Setup(SenderCore), Complete, Error, } -impl From> for SenderError { - fn from(value: enum_try_as_inner::Error) -> Self { - SenderError::StateError(value.to_string()) - } -} - /// Chou-Orlandi sender. #[derive(Debug)] pub struct Sender { @@ -76,10 +71,8 @@ impl OTSetup for Sender { return Ok(()); } - let sender = self - .state - .replace(State::Error) - .into_initialized() + let sender = std::mem::replace(&mut self.state, State::Error) + .try_into_initialized() .map_err(SenderError::from)?; // If the receiver is committed, we run the cointoss protocol @@ -111,7 +104,7 @@ async fn execute_cointoss< let commitment = stream .expect_next() .await? - .into_cointoss_sender_commitment()?; + .try_into_cointoss_sender_commitment()?; let (receiver, payload) = receiver.reveal(commitment)?; @@ -132,16 +125,14 @@ impl OTSender<[Block; 2]> for Sender { stream: &mut St, input: &[[Block; 2]], ) -> Result<(), OTError> { - let mut sender = self - .state - .replace(State::Error) - .into_setup() + let mut sender = std::mem::replace(&mut self.state, State::Error) + .try_into_setup() .map_err(SenderError::from)?; let receiver_payload = stream .expect_next() .await? - .into_receiver_payload() + .try_into_receiver_payload() .map_err(SenderError::from)?; let input = input.to_vec(); @@ -171,10 +162,8 @@ impl VerifiableOTSender for Sender { _sink: &mut Si, stream: &mut St, ) -> Result, OTError> { - let sender = self - .state - .replace(State::Error) - .into_setup() + let sender = std::mem::replace(&mut self.state, State::Error) + .try_into_setup() .map_err(SenderError::from)?; let Some(cointoss_receiver) = self.cointoss_receiver.take() else { @@ -186,13 +175,13 @@ impl VerifiableOTSender for Sender { let cointoss_payload = stream .expect_next() .await? - .into_cointoss_sender_payload() + .try_into_cointoss_sender_payload() .map_err(SenderError::from)?; let receiver_reveal = stream .expect_next() .await? - .into_receiver_reveal() + .try_into_receiver_reveal() .map_err(SenderError::from)?; let cointoss_seed = cointoss_receiver diff --git a/ot/mpz-ot/src/kos/error.rs b/ot/mpz-ot/src/kos/error.rs index 43435427..8ebce4a5 100644 --- a/ot/mpz-ot/src/kos/error.rs +++ b/ot/mpz-ot/src/kos/error.rs @@ -1,4 +1,4 @@ -use mpz_ot_core::kos::msgs::Message; +use mpz_ot_core::kos::msgs::MessageError; use crate::OTError; @@ -14,7 +14,7 @@ pub enum SenderError { BaseOTError(#[from] crate::OTError), #[error(transparent)] CointossError(#[from] mpz_core::cointoss::CointossError), - #[error("invalid state: expected {0}")] + #[error("{0}")] StateError(String), #[error("configuration error: {0}")] ConfigError(String), @@ -31,17 +31,23 @@ impl From for OTError { } } +impl From for SenderError { + fn from(err: crate::kos::SenderStateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + impl From for OTError { fn from(err: mpz_ot_core::kos::SenderError) -> Self { SenderError::from(err).into() } } -impl From>> for SenderError { - fn from(value: enum_try_as_inner::Error>) -> Self { +impl From> for SenderError { + fn from(err: MessageError) -> Self { SenderError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, - value.to_string(), + err.to_string(), )) } } @@ -58,7 +64,7 @@ pub enum ReceiverError { BaseOTError(#[from] crate::OTError), #[error(transparent)] CointossError(#[from] mpz_core::cointoss::CointossError), - #[error("invalid state: expected {0}")] + #[error("{0}")] StateError(String), #[error("configuration error: {0}")] ConfigError(String), @@ -77,17 +83,23 @@ impl From for OTError { } } +impl From for ReceiverError { + fn from(err: crate::kos::ReceiverStateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} + impl From for OTError { fn from(err: mpz_ot_core::kos::ReceiverError) -> Self { ReceiverError::from(err).into() } } -impl From>> for ReceiverError { - fn from(value: enum_try_as_inner::Error>) -> Self { +impl From> for ReceiverError { + fn from(err: MessageError) -> Self { ReceiverError::from(std::io::Error::new( std::io::ErrorKind::InvalidData, - value.to_string(), + err.to_string(), )) } } diff --git a/ot/mpz-ot/src/kos/mod.rs b/ot/mpz-ot/src/kos/mod.rs index f700bd7f..c9ffb8b1 100644 --- a/ot/mpz-ot/src/kos/mod.rs +++ b/ot/mpz-ot/src/kos/mod.rs @@ -9,8 +9,8 @@ use futures_util::{SinkExt, StreamExt}; pub use receiver::Receiver; pub use sender::Sender; -pub(crate) use receiver::State as ReceiverState; -pub(crate) use sender::State as SenderState; +pub(crate) use receiver::StateError as ReceiverStateError; +pub(crate) use sender::StateError as SenderStateError; pub use mpz_ot_core::kos::{ msgs, PayloadRecord, ReceiverConfig, ReceiverConfigBuilder, ReceiverConfigBuilderError, @@ -32,13 +32,7 @@ pub(crate) fn into_base_stream<'a, St: IoStream> + Send + Unpin stream: &'a mut St, ) -> impl IoStream + Send + Unpin + 'a { StreamExt::map(stream, |msg| match msg { - Ok(msg) => match msg.into_base_msg() { - Ok(msg) => Ok(msg), - Err(err) => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - err.to_string(), - )), - }, + Ok(msg) => msg.try_into_base_msg().map_err(From::from), Err(err) => Err(err), }) } diff --git a/ot/mpz-ot/src/kos/receiver.rs b/ot/mpz-ot/src/kos/receiver.rs index 8fe4b11d..ed3661e0 100644 --- a/ot/mpz-ot/src/kos/receiver.rs +++ b/ot/mpz-ot/src/kos/receiver.rs @@ -19,6 +19,7 @@ use crate::{OTError, OTReceiver, OTSender, OTSetup, VerifiableOTReceiver, Verifi use super::{into_base_sink, into_base_stream, ReceiverError, ReceiverVerifyError}; #[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] pub(crate) enum State { Initialized(Box>), Extension(Box>), @@ -26,12 +27,6 @@ pub(crate) enum State { Error, } -impl From> for ReceiverError { - fn from(value: enum_try_as_inner::Error) -> Self { - ReceiverError::StateError(value.to_string()) - } -} - /// KOS receiver. #[derive(Debug)] pub struct Receiver { @@ -60,7 +55,7 @@ where /// The number of remaining OTs which can be consumed. pub fn remaining(&self) -> Result { - Ok(self.state.as_extension()?.remaining()) + Ok(self.state.try_as_extension()?.remaining()) } /// Returns a reference to the inner receiver state. @@ -89,7 +84,8 @@ where stream: &mut St, count: usize, ) -> Result<(), ReceiverError> { - let mut ext_receiver = self.state.replace(State::Error).into_extension()?; + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; // Extend the OTs, adding padding for the consistency check. let (mut ext_receiver, extend) = Backend::spawn(move || { @@ -115,7 +111,7 @@ where let cointoss_payload = stream .expect_next() .await? - .into_cointoss_receiver_payload()?; + .try_into_cointoss_receiver_payload()?; // Open commitment let (mut seeds, payload) = cointoss_sender.finalize(cointoss_payload)?; @@ -154,13 +150,13 @@ where sink: &mut Si, stream: &mut St, ) -> Result<(), ReceiverError> { - let receiver = self.state.replace(State::Error).into_extension()?; + let receiver = std::mem::replace(&mut self.state, State::Error).try_into_extension()?; // Finalize coin toss to determine expected delta let cointoss_payload = stream .expect_next() .await? - .into_cointoss_sender_payload() + .try_into_cointoss_sender_payload() .map_err(ReceiverError::from)?; let Some(cointoss_receiver) = self.cointoss_receiver.take() else { @@ -215,10 +211,8 @@ where return Ok(()); } - let ext_receiver = self - .state - .replace(State::Error) - .into_initialized() + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_initialized() .map_err(ReceiverError::from)?; // If the sender is committed, we run a coin toss @@ -226,7 +220,7 @@ where let commitment = stream .expect_next() .await? - .into_cointoss_commit() + .try_into_cointoss_commit() .map_err(ReceiverError::from)?; let (cointoss_receiver, payload) = cointoss::Receiver::new(vec![thread_rng().gen()]) @@ -276,7 +270,10 @@ where stream: &mut St, choices: &[bool], ) -> Result, OTError> { - let receiver = self.state.as_extension_mut().map_err(ReceiverError::from)?; + let receiver = self + .state + .try_as_extension_mut() + .map_err(ReceiverError::from)?; let mut receiver_keys = receiver.keys(choices.len()).map_err(ReceiverError::from)?; @@ -292,7 +289,7 @@ where let payload = stream .expect_next() .await? - .into_sender_payload() + .try_into_sender_payload() .map_err(ReceiverError::from)?; let received = Backend::spawn(move || { @@ -320,7 +317,10 @@ where stream: &mut St, choices: &[bool], ) -> Result, OTError> { - let receiver = self.state.as_extension_mut().map_err(ReceiverError::from)?; + let receiver = self + .state + .try_as_extension_mut() + .map_err(ReceiverError::from)?; let mut receiver_keys = receiver.keys(choices.len()).map_err(ReceiverError::from)?; @@ -336,7 +336,7 @@ where let payload = stream .expect_next() .await? - .into_sender_payload() + .try_into_sender_payload() .map_err(ReceiverError::from)?; let received = Backend::spawn(move || { @@ -370,7 +370,7 @@ where self.verify_delta(sink, stream).await?; } - let receiver = self.state.as_verify().map_err(ReceiverError::from)?; + let receiver = self.state.try_as_verify().map_err(ReceiverError::from)?; let record = receiver .remove_record(id as u32) diff --git a/ot/mpz-ot/src/kos/sender.rs b/ot/mpz-ot/src/kos/sender.rs index 863dc72a..a1a21411 100644 --- a/ot/mpz-ot/src/kos/sender.rs +++ b/ot/mpz-ot/src/kos/sender.rs @@ -22,6 +22,7 @@ use enum_try_as_inner::EnumTryAsInner; use super::{into_base_sink, into_base_stream}; #[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] pub(crate) enum State { Initialized(SenderCore), Extension(SenderCore), @@ -29,12 +30,6 @@ pub(crate) enum State { Error, } -impl From> for SenderError { - fn from(value: enum_try_as_inner::Error) -> Self { - SenderError::StateError(value.to_string()) - } -} - /// KOS sender. #[derive(Debug)] pub struct Sender { @@ -63,7 +58,7 @@ where /// The number of remaining OTs which can be consumed. pub fn remaining(&self) -> Result { - Ok(self.state.as_extension()?.remaining()) + Ok(self.state.try_as_extension()?.remaining()) } /// Returns a mutable reference to the inner sender state. @@ -87,7 +82,7 @@ where stream: &mut St, delta: Block, ) -> Result<(), SenderError> { - if self.state.as_initialized()?.config().sender_commit() { + if self.state.try_as_initialized()?.config().sender_commit() { return Err(SenderError::ConfigError( "committed sender can not choose delta".to_string(), )); @@ -105,7 +100,7 @@ where stream: &mut St, delta: Block, ) -> Result<(), SenderError> { - let ext_sender = self.state.replace(State::Error).into_initialized()?; + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; let choices = delta.into_lsb0_vec(); let seeds = self @@ -141,17 +136,18 @@ where stream: &mut St, count: usize, ) -> Result<(), SenderError> { - let mut ext_sender = self.state.replace(State::Error).into_extension()?; + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; // Receive extend message from the receiver. let extend = stream .expect_next() .await? - .into_extend() + .try_into_extend() .map_err(SenderError::from)?; // Receive coin toss commitments from the receiver. - let commitment = stream.expect_next().await?.into_cointoss_commit()?; + let commitment = stream.expect_next().await?.try_into_cointoss_commit()?; // Extend the OTs, adding padding for the consistency check. let mut ext_sender = Backend::spawn(move || { @@ -172,10 +168,13 @@ where .await?; // Receive coin toss sender payload from the receiver. - let cointoss_sender_payload = stream.expect_next().await?.into_cointoss_sender_payload()?; + let cointoss_sender_payload = stream + .expect_next() + .await? + .try_into_cointoss_sender_payload()?; // Receive consistency check from the receiver. - let receiver_check = stream.expect_next().await?.into_check()?; + let receiver_check = stream.expect_next().await?.try_into_check()?; // Derive chi seed for the consistency check. let chi_seed = cointoss_receiver.finalize(cointoss_sender_payload)?[0]; @@ -206,11 +205,7 @@ where sink: &mut Si, stream: &mut St, ) -> Result<(), SenderError> { - let _ = self - .state - .replace(State::Error) - .into_extension() - .map_err(SenderError::from)?; + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; // Reveal coin toss payload let Some(payload) = self.cointoss_payload.take() else { @@ -259,10 +254,8 @@ where return Ok(()); } - let sender = self - .state - .replace(State::Error) - .into_initialized() + let sender = std::mem::replace(&mut self.state, State::Error) + .try_into_initialized() .map_err(SenderError::from)?; // If the sender is committed, we sample delta using a coin toss. @@ -274,7 +267,7 @@ where let payload = stream .expect_next() .await? - .into_cointoss_receiver_payload() + .try_into_cointoss_receiver_payload() .map_err(SenderError::from)?; let (seeds, payload) = cointoss_sender @@ -316,12 +309,15 @@ where stream: &mut St, msgs: &[[Block; 2]], ) -> Result<(), OTError> { - let sender = self.state.as_extension_mut().map_err(SenderError::from)?; + let sender = self + .state + .try_as_extension_mut() + .map_err(SenderError::from)?; let derandomize = stream .expect_next() .await? - .into_derandomize() + .try_into_derandomize() .map_err(SenderError::from)?; let mut sender_keys = sender.keys(msgs.len()).map_err(SenderError::from)?; @@ -354,12 +350,15 @@ where stream: &mut St, msgs: &[[[u8; N]; 2]], ) -> Result<(), OTError> { - let sender = self.state.as_extension_mut().map_err(SenderError::from)?; + let sender = self + .state + .try_as_extension_mut() + .map_err(SenderError::from)?; let derandomize = stream .expect_next() .await? - .into_derandomize() + .try_into_derandomize() .map_err(SenderError::from)?; let mut sender_keys = sender.keys(msgs.len()).map_err(SenderError::from)?;