diff --git a/console/program/src/data_types/mod.rs b/console/program/src/data_types/mod.rs index 4d3175b771..acc93c9a4f 100644 --- a/console/program/src/data_types/mod.rs +++ b/console/program/src/data_types/mod.rs @@ -35,4 +35,4 @@ mod struct_type; pub use struct_type::StructType; mod value_type; -pub use value_type::ValueType; +pub use value_type::{ValueType, Variant}; diff --git a/console/program/src/data_types/value_type/mod.rs b/console/program/src/data_types/value_type/mod.rs index 6eff21d476..136f593b98 100644 --- a/console/program/src/data_types/value_type/mod.rs +++ b/console/program/src/data_types/value_type/mod.rs @@ -22,6 +22,8 @@ use snarkvm_console_network::prelude::*; use enum_index::EnumIndex; +pub type Variant = u8; + #[derive(Clone, PartialEq, Eq, Hash, EnumIndex)] pub enum ValueType { /// A constant type. @@ -38,6 +40,20 @@ pub enum ValueType { Future(Locator), } +impl ValueType { + /// Returns the variant of the value type. + pub const fn variant(&self) -> Variant { + match self { + ValueType::Constant(..) => 0, + ValueType::Public(..) => 1, + ValueType::Private(..) => 2, + ValueType::Record(..) => 3, + ValueType::ExternalRecord(..) => 4, + ValueType::Future(..) => 5, + } + } +} + impl From> for ValueType { fn from(entry: EntryType) -> Self { match entry { diff --git a/synthesizer/process/src/lib.rs b/synthesizer/process/src/lib.rs index 65bfede98e..85185f39f8 100644 --- a/synthesizer/process/src/lib.rs +++ b/synthesizer/process/src/lib.rs @@ -49,7 +49,7 @@ use console::{ program::{Identifier, Literal, Locator, Plaintext, ProgramID, Record, Response, Value, compute_function_id}, types::{Field, U16, U64}, }; -use ledger_block::{Deployment, Execution, Fee, Input, Transition}; +use ledger_block::{Deployment, Execution, Fee, Input, Output, Transition}; use ledger_store::{FinalizeStorage, FinalizeStore, atomic_batch_scope}; use synthesizer_program::{ Branch, diff --git a/synthesizer/process/src/verify_execution.rs b/synthesizer/process/src/verify_execution.rs index bc993a3e55..da4d9fbe89 100644 --- a/synthesizer/process/src/verify_execution.rs +++ b/synthesizer/process/src/verify_execution.rs @@ -112,6 +112,14 @@ impl Process { ensure!(function.inputs().len() == num_inputs, "The number of transition inputs is incorrect"); ensure!(function.outputs().len() == num_outputs, "The number of transition outputs is incorrect"); + // Ensure the input and output types are equivalent to the ones defined in the function. + // We only need to check that the variant type matches because we already check the hashes in + // the `Input::verify` and `Output::verify` functions. + let transition_input_variants = transition.inputs().iter().map(Input::variant).collect::>(); + let transition_output_variants = transition.outputs().iter().map(Output::variant).collect::>(); + ensure!(function.input_variants() == transition_input_variants, "The input variants do not match"); + ensure!(function.output_variants() == transition_output_variants, "The output variants do not match"); + // Retrieve the parent program ID. // Note: The last transition in the execution does not have a parent, by definition. let parent = reverse_call_graph.get(transition.id()).and_then(|tid| execution.get_program_id(tid)); diff --git a/synthesizer/process/src/verify_fee.rs b/synthesizer/process/src/verify_fee.rs index e9ec997a6b..7f4d8e1fe9 100644 --- a/synthesizer/process/src/verify_fee.rs +++ b/synthesizer/process/src/verify_fee.rs @@ -22,13 +22,14 @@ impl Process { pub fn verify_fee(&self, fee: &Fee, deployment_or_execution_id: Field) -> Result<()> { let timer = timer!("Process::verify_fee"); + // Retrieve the stack. + let stack = self.get_stack(fee.program_id())?; + // Retrieve the function from the stack. + let function = stack.get_function(fee.function_name())?; + #[cfg(debug_assertions)] { println!("Verifying fee from {}/{}...", fee.program_id(), fee.function_name()); - // Retrieve the stack. - let stack = self.get_stack(fee.program_id())?; - // Retrieve the function from the stack. - let function = stack.get_function(fee.function_name())?; // Ensure the number of function calls in this function is 1. if stack.get_number_of_calls(function.name())? != 1 { bail!("The number of function calls in '{}/{}' should be 1", stack.program_id(), function.name()) @@ -52,6 +53,14 @@ impl Process { // Ensure the number of outputs is within the allowed range. ensure!(fee.outputs().len() <= N::MAX_INPUTS, "Fee exceeded maximum number of outputs"); + // Ensure the input and output types are equivalent to the ones defined in the function. + // We only need to check that the variant type matches because we already check the hashes in + // the `Input::verify` and `Output::verify` functions. + let fee_input_variants = fee.inputs().iter().map(Input::variant).collect::>(); + let fee_output_variants = fee.outputs().iter().map(Output::variant).collect::>(); + ensure!(function.input_variants() == fee_input_variants, "The fee input variants do not match"); + ensure!(function.output_variants() == fee_output_variants, "The fee output variants do not match"); + // Retrieve the candidate deployment or execution ID. let Ok(candidate_id) = fee.deployment_or_execution_id() else { bail!("Failed to get the deployment or execution ID in the fee transition") diff --git a/synthesizer/program/src/function/mod.rs b/synthesizer/program/src/function/mod.rs index cb1a9b6eef..caebf538b5 100644 --- a/synthesizer/program/src/function/mod.rs +++ b/synthesizer/program/src/function/mod.rs @@ -28,7 +28,7 @@ use crate::{ }; use console::{ network::prelude::*, - program::{Identifier, Register, ValueType}, + program::{Identifier, Register, ValueType, Variant}, }; use indexmap::IndexSet; @@ -69,6 +69,11 @@ impl, Command: CommandTrait> Fun self.inputs.iter().map(|input| input.value_type()).cloned().collect() } + /// Returns the function input type variants. + pub fn input_variants(&self) -> Vec { + self.inputs.iter().map(|input| input.value_type().variant()).collect() + } + /// Returns the function instructions. pub fn instructions(&self) -> &[Instruction] { &self.instructions @@ -84,6 +89,11 @@ impl, Command: CommandTrait> Fun self.outputs.iter().map(|output| output.value_type()).cloned().collect() } + /// Returns the function output type variants. + pub fn output_variants(&self) -> Vec { + self.outputs.iter().map(|output| output.value_type().variant()).collect() + } + /// Returns the function finalize logic. pub const fn finalize_logic(&self) -> Option<&FinalizeCore> { self.finalize_logic.as_ref() diff --git a/synthesizer/program/src/lib.rs b/synthesizer/program/src/lib.rs index 3c71b0dd0d..f6ad7c30bf 100644 --- a/synthesizer/program/src/lib.rs +++ b/synthesizer/program/src/lib.rs @@ -812,6 +812,10 @@ function swap: assert_eq!(function.input_types()[0], ValueType::ExternalRecord(Locator::from_str("eth.aleo/eth")?)); assert_eq!(function.input_types()[1], ValueType::ExternalRecord(Locator::from_str("usdc.aleo/usdc")?)); + // Ensure the input variants are correct. + assert_eq!(function.input_types()[0].variant(), 4); + assert_eq!(function.input_types()[1].variant(), 4); + // Ensure there are two instructions. assert_eq!(function.instructions().len(), 2); @@ -827,6 +831,10 @@ function swap: assert_eq!(function.output_types()[0], ValueType::ExternalRecord(Locator::from_str("eth.aleo/eth")?)); assert_eq!(function.output_types()[1], ValueType::ExternalRecord(Locator::from_str("usdc.aleo/usdc")?)); + // Ensure the output variants are correct. + assert_eq!(function.output_types()[0].variant(), 4); + assert_eq!(function.output_types()[1].variant(), 4); + Ok(()) } }