Skip to content

Commit

Permalink
Add additional input/output checks
Browse files Browse the repository at this point in the history
  • Loading branch information
raychu86 committed Dec 2, 2024
1 parent 1de86e7 commit 072d0ce
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 7 deletions.
2 changes: 1 addition & 1 deletion console/program/src/data_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ mod struct_type;
pub use struct_type::StructType;

mod value_type;
pub use value_type::ValueType;
pub use value_type::{ValueType, Variant};
16 changes: 16 additions & 0 deletions console/program/src/data_types/value_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use snarkvm_console_network::prelude::*;

use enum_index::EnumIndex;

pub type Variant = u8;

#[derive(Clone, PartialEq, Eq, Hash, EnumIndex)]
pub enum ValueType<N: Network> {
/// A constant type.
Expand All @@ -38,6 +40,20 @@ pub enum ValueType<N: Network> {
Future(Locator<N>),
}

impl<N: Network> ValueType<N> {
/// Returns the variant of the value type.
pub const fn variant(&self) -> Variant {
match self {
ValueType::Constant(..) => 0,
ValueType::Public(..) => 1,
ValueType::Private(..) => 2,
ValueType::Record(..) => 3,
ValueType::ExternalRecord(..) => 4,
ValueType::Future(..) => 5,
}
}
}

impl<N: Network> From<EntryType<N>> for ValueType<N> {
fn from(entry: EntryType<N>) -> Self {
match entry {
Expand Down
2 changes: 1 addition & 1 deletion synthesizer/process/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use console::{
program::{Identifier, Literal, Locator, Plaintext, ProgramID, Record, Response, Value, compute_function_id},
types::{Field, U16, U64},
};
use ledger_block::{Deployment, Execution, Fee, Input, Transition};
use ledger_block::{Deployment, Execution, Fee, Input, Output, Transition};
use ledger_store::{FinalizeStorage, FinalizeStore, atomic_batch_scope};
use synthesizer_program::{
Branch,
Expand Down
8 changes: 8 additions & 0 deletions synthesizer/process/src/verify_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ impl<N: Network> Process<N> {
ensure!(function.inputs().len() == num_inputs, "The number of transition inputs is incorrect");
ensure!(function.outputs().len() == num_outputs, "The number of transition outputs is incorrect");

// Ensure the input and output types are equivalent to the ones defined in the function.
// We only need to check that the variant type matches because we already check the hashes in
// the `Input::verify` and `Output::verify` functions.
let transition_input_variants = transition.inputs().iter().map(Input::variant).collect::<Vec<_>>();
let transition_output_variants = transition.outputs().iter().map(Output::variant).collect::<Vec<_>>();
ensure!(function.input_variants() == transition_input_variants, "The input variants do not match");
ensure!(function.output_variants() == transition_output_variants, "The output variants do not match");

// Retrieve the parent program ID.
// Note: The last transition in the execution does not have a parent, by definition.
let parent = reverse_call_graph.get(transition.id()).and_then(|tid| execution.get_program_id(tid));
Expand Down
17 changes: 13 additions & 4 deletions synthesizer/process/src/verify_fee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ impl<N: Network> Process<N> {
pub fn verify_fee(&self, fee: &Fee<N>, deployment_or_execution_id: Field<N>) -> Result<()> {
let timer = timer!("Process::verify_fee");

// Retrieve the stack.
let stack = self.get_stack(fee.program_id())?;
// Retrieve the function from the stack.
let function = stack.get_function(fee.function_name())?;

#[cfg(debug_assertions)]
{
println!("Verifying fee from {}/{}...", fee.program_id(), fee.function_name());
// Retrieve the stack.
let stack = self.get_stack(fee.program_id())?;
// Retrieve the function from the stack.
let function = stack.get_function(fee.function_name())?;
// Ensure the number of function calls in this function is 1.
if stack.get_number_of_calls(function.name())? != 1 {
bail!("The number of function calls in '{}/{}' should be 1", stack.program_id(), function.name())
Expand All @@ -52,6 +53,14 @@ impl<N: Network> Process<N> {
// Ensure the number of outputs is within the allowed range.
ensure!(fee.outputs().len() <= N::MAX_INPUTS, "Fee exceeded maximum number of outputs");

// Ensure the input and output types are equivalent to the ones defined in the function.
// We only need to check that the variant type matches because we already check the hashes in
// the `Input::verify` and `Output::verify` functions.
let fee_input_variants = fee.inputs().iter().map(Input::variant).collect::<Vec<_>>();
let fee_output_variants = fee.outputs().iter().map(Output::variant).collect::<Vec<_>>();
ensure!(function.input_variants() == fee_input_variants, "The fee input variants do not match");
ensure!(function.output_variants() == fee_output_variants, "The fee output variants do not match");

// Retrieve the candidate deployment or execution ID.
let Ok(candidate_id) = fee.deployment_or_execution_id() else {
bail!("Failed to get the deployment or execution ID in the fee transition")
Expand Down
12 changes: 11 additions & 1 deletion synthesizer/program/src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
};
use console::{
network::prelude::*,
program::{Identifier, Register, ValueType},
program::{Identifier, Register, ValueType, Variant},
};

use indexmap::IndexSet;
Expand Down Expand Up @@ -69,6 +69,11 @@ impl<N: Network, Instruction: InstructionTrait<N>, Command: CommandTrait<N>> Fun
self.inputs.iter().map(|input| input.value_type()).cloned().collect()
}

/// Returns the function input type variants.
pub fn input_variants(&self) -> Vec<Variant> {
self.inputs.iter().map(|input| input.value_type().variant()).collect()
}

/// Returns the function instructions.
pub fn instructions(&self) -> &[Instruction] {
&self.instructions
Expand All @@ -84,6 +89,11 @@ impl<N: Network, Instruction: InstructionTrait<N>, Command: CommandTrait<N>> Fun
self.outputs.iter().map(|output| output.value_type()).cloned().collect()
}

/// Returns the function output type variants.
pub fn output_variants(&self) -> Vec<Variant> {
self.outputs.iter().map(|output| output.value_type().variant()).collect()
}

/// Returns the function finalize logic.
pub const fn finalize_logic(&self) -> Option<&FinalizeCore<N, Command>> {
self.finalize_logic.as_ref()
Expand Down
8 changes: 8 additions & 0 deletions synthesizer/program/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,10 @@ function swap:
assert_eq!(function.input_types()[0], ValueType::ExternalRecord(Locator::from_str("eth.aleo/eth")?));
assert_eq!(function.input_types()[1], ValueType::ExternalRecord(Locator::from_str("usdc.aleo/usdc")?));

// Ensure the input variants are correct.
assert_eq!(function.input_types()[0].variant(), 4);
assert_eq!(function.input_types()[1].variant(), 4);

// Ensure there are two instructions.
assert_eq!(function.instructions().len(), 2);

Expand All @@ -827,6 +831,10 @@ function swap:
assert_eq!(function.output_types()[0], ValueType::ExternalRecord(Locator::from_str("eth.aleo/eth")?));
assert_eq!(function.output_types()[1], ValueType::ExternalRecord(Locator::from_str("usdc.aleo/usdc")?));

// Ensure the output variants are correct.
assert_eq!(function.output_types()[0].variant(), 4);
assert_eq!(function.output_types()[1].variant(), 4);

Ok(())
}
}

0 comments on commit 072d0ce

Please sign in to comment.