Skip to content

Commit

Permalink
add serialization and deserialization for ContractClass
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoratger committed Jul 25, 2024
1 parent 6eb75c0 commit 1b3bb4c
Show file tree
Hide file tree
Showing 7 changed files with 14,724 additions and 10 deletions.
208 changes: 200 additions & 8 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Deref;
use std::sync::Arc;

use cairo_felt::Felt252;
use cairo_lang_casm;
use cairo_lang_casm::hints::Hint;
use cairo_lang_sierra::program::Program as SierraProgram;
use cairo_lang_starknet_classes::casm_contract_class::{
Expand All @@ -22,21 +23,22 @@ use cairo_vm::types::relocatable::MaybeRelocatable;
use cairo_vm::vm::runners::builtin_runner::{HASH_BUILTIN_NAME, POSEIDON_BUILTIN_NAME};
use cairo_vm::vm::runners::cairo_runner::ExecutionResources;
use itertools::Itertools;
use num_traits::Num;
use serde::de::Error as DeserializationError;
use serde::{Deserialize, Deserializer};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use starknet_api::core::EntryPointSelector;
use starknet_api::deprecated_contract_class::{
ContractClass as DeprecatedContractClass, EntryPoint, EntryPointOffset, EntryPointType,
Program as DeprecatedProgram,
};

use super::execution_utils::poseidon_hash_many_cost;
use super::native::utils::contract_entrypoint_to_entrypoint_selector;
use super::execution_utils::{cairo_vm_to_sn_api_program, poseidon_hash_many_cost};
use crate::abi::abi_utils::selector_from_name;
use crate::abi::constants::{self, CONSTRUCTOR_ENTRY_POINT_NAME};
use crate::execution::entry_point::CallEntryPoint;
use crate::execution::errors::{ContractClassError, PreExecutionError};
use crate::execution::execution_utils::{felt_to_stark_felt, sn_api_to_cairo_vm_program};
use crate::execution::native::utils::contract_entrypoint_to_entrypoint_selector;
use crate::fee::eth_gas_constants;
use crate::transaction::errors::TransactionExecutionError;

Expand All @@ -51,7 +53,7 @@ pub mod test;

pub type ContractClassResult<T> = Result<T, ContractClassError>;

#[derive(Clone, Debug, Eq, PartialEq, derive_more::From)]
#[derive(Clone, Debug, Eq, PartialEq, derive_more::From, Serialize, Deserialize)]
pub enum ContractClass {
V0(ContractClassV0),
V1(ContractClassV1),
Expand Down Expand Up @@ -98,7 +100,7 @@ impl ContractClass {
}

// V0.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Default, Serialize, Deserialize, Eq, PartialEq)]
pub struct ContractClassV0(pub Arc<ContractClassV0Inner>);
impl Deref for ContractClassV0 {
type Target = ContractClassV0Inner;
Expand Down Expand Up @@ -149,9 +151,9 @@ impl ContractClassV0 {
}
}

#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Default, Serialize, Deserialize, Eq, PartialEq)]
pub struct ContractClassV0Inner {
#[serde(deserialize_with = "deserialize_program")]
#[serde(deserialize_with = "deserialize_program", serialize_with = "serialize_program")]
pub program: Program,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPoint>>,
}
Expand All @@ -178,6 +180,39 @@ impl Deref for ContractClassV1 {
}
}

impl Serialize for ContractClassV1 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Convert the ContractClassV1 instance to CasmContractClass
let casm_contract_class: CasmContractClass = self
.try_into()
.map_err(|err: ProgramError| serde::ser::Error::custom(err.to_string()))?;

// Serialize the JSON string to bytes
casm_contract_class.serialize(serializer)
}
}

impl<'de> Deserialize<'de> for ContractClassV1 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Deserialize into a JSON value
let json_value: serde_json::Value = Deserialize::deserialize(deserializer)?;

// Convert into a JSON string
let json_string = serde_json::to_string(&json_value)
.map_err(|err| DeserializationError::custom(err.to_string()))?;

// Use try_from_json_string to deserialize into ContractClassV1
ContractClassV1::try_from_json_string(&json_string)
.map_err(|err| DeserializationError::custom(err.to_string()))
}
}

impl ContractClassV1 {
fn constructor_selector(&self) -> Option<EntryPointSelector> {
Some(self.0.entry_points_by_type[&EntryPointType::Constructor].first()?.selector)
Expand Down Expand Up @@ -352,7 +387,7 @@ pub struct ContractClassV1Inner {
bytecode_segment_lengths: NestedIntList,
}

#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct EntryPointV1 {
pub selector: EntryPointSelector,
pub offset: EntryPointOffset,
Expand All @@ -365,6 +400,119 @@ impl EntryPointV1 {
}
}

// Implementation of the TryInto trait to convert a reference of ContractClassV1 into
// CasmContractClass.
impl TryInto<CasmContractClass> for &ContractClassV1 {
// Definition of the error type that can be returned during the conversion.
type Error = ProgramError;

// Implementation of the try_into function which performs the conversion.
fn try_into(self) -> Result<CasmContractClass, Self::Error> {
// Converting the program data into a vector of BigUintAsHex.
let bytecode: Vec<cairo_lang_utils::bigint::BigUintAsHex> = self
.program
.iter_data()
.map(|x| cairo_lang_utils::bigint::BigUintAsHex {
value: x.get_int_ref().unwrap().to_biguint(),
})
.collect();

// Serialize the Program object to JSON bytes.
let serialized_program = self.program.serialize()?;
// Deserialize the JSON bytes into a serde_json::Value.
let json_value: serde_json::Value = serde_json::from_slice(&serialized_program)?;

// Extract the hints from the JSON value.
let hints = json_value.get("hints").ok_or_else(|| {
ProgramError::Parse(serde::ser::Error::custom("failed to parse hints"))
})?;

// Transform the hints into a vector of tuples (usize, Vec<Hint>).
let hints: Vec<(usize, Vec<Hint>)> = hints
.as_object() // Convert to JSON object.
.unwrap()
.iter()
.map(|(key, value)| {
// Transform each hint value into a Vec<Hint>.
let hints: Vec<Hint> = value
.as_array() // Convert to JSON array.
.unwrap()
.iter()
.map(|hint_params| {
// Extract the "code" parameter and convert to a string.
let hint_param_code = hint_params.get("code").unwrap().clone();
let hint_string = hint_param_code.as_str().expect("failed to parse hint as string");

// Retrieve the hint from the self.hints map.
self.hints.get(hint_string).expect("failed to get hint").clone()
})
.collect();
// Convert the key to usize and create a tuple (usize, Vec<Hint>).
(key.parse().unwrap(), hints)
})
.collect();

// Define the bytecode segment lengths
let bytecode_segment_lengths = Some(self.bytecode_segment_lengths.clone());

// Transform the entry points of type Constructor into CasmContractEntryPoint.
let constructor = self
.entry_points_by_type
.get(&EntryPointType::Constructor)
.unwrap_or(&vec![])
.iter()
.map(|constructor| CasmContractEntryPoint {
selector: num_bigint::BigUint::from_bytes_be(constructor.selector.0.bytes()),
offset: constructor.offset.0,
builtins: constructor.builtins.clone(),
})
.collect();

// Transform the entry points of type External into CasmContractEntryPoint.
let external = self
.entry_points_by_type
.get(&EntryPointType::External)
.unwrap_or(&vec![])
.iter()
.map(|external| CasmContractEntryPoint {
selector: num_bigint::BigUint::from_bytes_be(external.selector.0.bytes()),
offset: external.offset.0,
builtins: external.builtins.clone(),
})
.collect();

// Transform the entry points of type L1Handler into CasmContractEntryPoint.
let l1_handler = self
.entry_points_by_type
.get(&EntryPointType::L1Handler)
.unwrap_or(&vec![])
.iter()
.map(|l1_handler| CasmContractEntryPoint {
selector: num_bigint::BigUint::from_bytes_be(l1_handler.selector.0.bytes()),
offset: l1_handler.offset.0,
builtins: l1_handler.builtins.clone(),
})
.collect();

// Construct the CasmContractClass from the extracted and transformed data.
Ok(CasmContractClass {
prime: num_bigint::BigUint::from_str_radix(&self.program.prime()[2..], 16)
.expect("failed to parse prime"),
compiler_version: "".to_string(),
bytecode,
bytecode_segment_lengths,
hints,
pythonic_hints: None,
entry_points_by_type:
cairo_lang_starknet_classes::casm_contract_class::CasmContractEntryPoints {
constructor,
external,
l1_handler,
},
})
}
}

impl TryFrom<CasmContractClass> for ContractClassV1 {
type Error = ProgramError;

Expand Down Expand Up @@ -447,6 +595,16 @@ pub fn deserialize_program<'de, D: Deserializer<'de>>(
.map_err(|err| DeserializationError::custom(err.to_string()))
}

/// Converts the program type from Cairo VM into a SN API-compatible type.
pub fn serialize_program<S: Serializer>(
program: &Program,
serializer: S,
) -> Result<S::Ok, S::Error> {
let deprecated_program = cairo_vm_to_sn_api_program(program.clone())
.map_err(|err| serde::ser::Error::custom(err.to_string()))?;
deprecated_program.serialize(serializer)
}

// V1 utilities.

// TODO(spapini): Share with cairo-lang-runner.
Expand Down Expand Up @@ -588,6 +746,40 @@ impl TryFrom<SierraContractClass> for SierraContractClassV1 {
}
}

impl Serialize for SierraContractClassV1 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Convert the SierraContractClassV1 instance to CasmContractClass
let casm_contract_class: CasmContractClass = self
.clone()
.to_casm_contract_class()
.map_err(|err| serde::ser::Error::custom(err.to_string()))?;

// Serialize the JSON string to bytes
casm_contract_class.serialize(serializer)
}
}

impl<'de> Deserialize<'de> for SierraContractClassV1 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Deserialize into a JSON value
let json_value: serde_json::Value = Deserialize::deserialize(deserializer)?;

// Convert into a JSON string
let json_string = serde_json::to_string(&json_value)
.map_err(|err| DeserializationError::custom(err.to_string()))?;

// Use try_from_json_string to deserialize into ContractClassV1
SierraContractClassV1::try_from_json_string(&json_string)
.map_err(|err| DeserializationError::custom(err.to_string()))
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SierraContractClassV1Inner {
pub sierra_program: SierraProgram,
Expand Down
57 changes: 56 additions & 1 deletion crates/blockifier/src/execution/contract_class_test.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::collections::HashSet;
use std::fs;
use std::sync::Arc;

use assert_matches::assert_matches;
use cairo_lang_starknet_classes::NestedIntList;
use rstest::rstest;

use crate::execution::contract_class::{ContractClassV1, ContractClassV1Inner};
use crate::execution::contract_class::{ContractClassV0, ContractClassV1, ContractClassV1Inner};
use crate::transaction::errors::TransactionExecutionError;

#[rstest]
Expand Down Expand Up @@ -42,3 +43,57 @@ fn test_get_visited_segments() {
TransactionExecutionError::InvalidSegmentStructure(907, 807)
);
}

#[test]
fn test_deserialization_of_contract_class_v_0() {
let contract_class: ContractClassV0 =
serde_json::from_slice(&fs::read("./tests/cairo0/counter.json").unwrap())
.expect("failed to deserialize contract class from file");

assert_eq!(contract_class, ContractClassV0::from_file("./tests/cairo0/counter.json"));

// Serialize the ContractClassV0 instance to JSON
let serialized_contract_class =
serde_json::to_string_pretty(&contract_class).expect("Failed to serialize");

// Save the serialized JSON to a file
let output_path = std::path::Path::new("./tests/cairo0/serialized_output.json");
fs::write(output_path, &serialized_contract_class)
.expect("Failed to write serialized JSON to file");

// Re-read the serialized file for inspection
let serialized_json_content =
fs::read_to_string(output_path).expect("Failed to read serialized JSON file");

// Deserialize from the serialized string
let _deserialized_contract_class: ContractClassV0 =
serde_json::from_str(&serialized_json_content)
.expect("failed to deserialize contract class from serialized string");
}

#[test]
fn test_deserialization_of_contract_class_v_1() {
let contract_class: ContractClassV1 =
serde_json::from_slice(&fs::read("./tests/cairo1/counter.json").unwrap())
.expect("failed to deserialize contract class from file");

assert_eq!(contract_class, ContractClassV1::from_file("./tests/cairo1/counter.json"));

// Serialize the ContractClassV0 instance to JSON
let serialized_contract_class =
serde_json::to_string_pretty(&contract_class).expect("Failed to serialize");

// Save the serialized JSON to a file
let output_path = std::path::Path::new("./tests/cairo1/serialized_output.json");
fs::write(output_path, &serialized_contract_class)
.expect("Failed to write serialized JSON to file");

// Re-read the serialized file for inspection
let serialized_json_content =
fs::read_to_string(output_path).expect("Failed to read serialized JSON file");

// Deserialize from the serialized string
let _deserialized_contract_class: ContractClassV1 =
serde_json::from_str(&serialized_json_content)
.expect("failed to deserialize contract class from serialized string");
}
Loading

0 comments on commit 1b3bb4c

Please sign in to comment.