From 6a0a9d277d6bce859c852444f782d5ba0896c2dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Sm=C3=B3=C5=82ka?= Date: Tue, 1 Oct 2024 22:36:52 +0200 Subject: [PATCH] Introduced data transformer commit-id:8d990066 --- Cargo.lock | 5 + crates/sncast/Cargo.toml | 5 + .../calldata_representation.rs | 220 ++++ .../src/helpers/data_transformer/mod.rs | 3 + .../helpers/data_transformer/sierra_abi.rs | 621 ++++++++++ .../helpers/data_transformer/transformer.rs | 161 +++ crates/sncast/src/helpers/mod.rs | 1 + .../tests/integration/data_transformer.rs | 1022 +++++++++++++++++ crates/sncast/tests/integration/mod.rs | 1 + 9 files changed, 2039 insertions(+) create mode 100644 crates/sncast/src/helpers/data_transformer/calldata_representation.rs create mode 100644 crates/sncast/src/helpers/data_transformer/mod.rs create mode 100644 crates/sncast/src/helpers/data_transformer/sierra_abi.rs create mode 100644 crates/sncast/src/helpers/data_transformer/transformer.rs create mode 100644 crates/sncast/tests/integration/data_transformer.rs diff --git a/Cargo.lock b/Cargo.lock index 06dfc1e98b..1beba1dda2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4663,9 +4663,13 @@ dependencies = [ "base16ct", "blockifier", "cairo-lang-casm", + "cairo-lang-diagnostics", + "cairo-lang-filesystem", + "cairo-lang-parser", "cairo-lang-runner", "cairo-lang-sierra", "cairo-lang-sierra-to-casm", + "cairo-lang-syntax", "cairo-lang-utils", "cairo-vm", "camino", @@ -4678,6 +4682,7 @@ dependencies = [ "fs_extra", "indoc", "itertools 0.12.1", + "num-bigint", "num-traits 0.2.19", "primitive-types", "project-root", diff --git a/crates/sncast/Cargo.toml b/crates/sncast/Cargo.toml index 550c8cf826..60067a007e 100644 --- a/crates/sncast/Cargo.toml +++ b/crates/sncast/Cargo.toml @@ -35,8 +35,13 @@ cairo-lang-casm.workspace = true cairo-lang-sierra-to-casm.workspace = true cairo-lang-utils.workspace = true cairo-lang-sierra.workspace = true +cairo-lang-parser.workspace = true +cairo-lang-syntax.workspace = true +cairo-lang-diagnostics.workspace = true +cairo-lang-filesystem.workspace = true itertools.workspace = true num-traits.workspace = true +num-bigint.workspace = true starknet-types-core.workspace = true cairo-vm.workspace = true blockifier.workspace = true diff --git a/crates/sncast/src/helpers/data_transformer/calldata_representation.rs b/crates/sncast/src/helpers/data_transformer/calldata_representation.rs new file mode 100644 index 0000000000..1b381f087b --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/calldata_representation.rs @@ -0,0 +1,220 @@ +use anyhow::{bail, Context}; +use conversions::{ + byte_array::ByteArray, + bytes31::CairoBytes31, + serde::serialize::{BufferWriter, CairoSerialize}, + u256::CairoU256, + u512::CairoU512, +}; +use starknet::core::types::Felt; +use std::{any::type_name, str::FromStr}; + +#[derive(Debug)] +pub(super) struct CalldataStructField(AllowedCalldataArguments); + +impl CalldataStructField { + pub fn new(value: AllowedCalldataArguments) -> Self { + Self(value) + } +} + +#[derive(Debug)] +pub(super) struct CalldataStruct(Vec); + +impl CalldataStruct { + pub fn new(arguments: Vec) -> Self { + Self(arguments) + } +} + +#[derive(Debug)] +pub(super) struct CalldataArrayMacro(Vec); + +impl CalldataArrayMacro { + pub fn new(arguments: Vec) -> Self { + Self(arguments) + } +} + +#[derive(Debug)] +pub(super) struct CalldataEnum { + position: usize, + argument: Option>, +} + +impl CalldataEnum { + pub fn new(position: usize, argument: Option>) -> Self { + Self { position, argument } + } +} + +#[derive(Debug)] +pub(super) enum CalldataPrimitive { + Bool(bool), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), + U256(CairoU256), + U512(CairoU512), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + I128(i128), + Felt(Felt), + ByteArray(ByteArray), +} + +fn neat_parsing_error_message(value: &str, parsing_type: &str, reason: Option<&str>) -> String { + if let Some(message) = reason { + format!(r#"Failed to parse value "{value}" into type "{parsing_type}": {message}"#) + } else { + format!(r#"Failed to parse value "{value}" into type "{parsing_type}""#) + } +} + +fn parse_with_type(value: &str) -> anyhow::Result +where + ::Err: std::error::Error + Send + Sync + 'static, +{ + value + .parse::() + .context(neat_parsing_error_message(value, type_name::(), None)) +} + +impl CalldataPrimitive { + pub(super) fn try_new(type_with_path: &str, value: &str) -> anyhow::Result { + let type_str = type_with_path + .split("::") + .last() + .context("Couldn't parse parameter type from ABI")?; + + // TODO add all corelib types (Issue #2550) + match type_str { + "u8" => Ok(Self::U8(parse_with_type(value)?)), + "u16" => Ok(Self::U16(parse_with_type(value)?)), + "u32" => Ok(Self::U32(parse_with_type(value)?)), + "u64" => Ok(Self::U64(parse_with_type(value)?)), + "u128" => Ok(Self::U128(parse_with_type(value)?)), + "u256" => Ok(Self::U256(parse_with_type(value)?)), + "u512" => Ok(Self::U512(parse_with_type(value)?)), + "i8" => Ok(Self::I8(parse_with_type(value)?)), + "i16" => Ok(Self::I16(parse_with_type(value)?)), + "i32" => Ok(Self::I32(parse_with_type(value)?)), + "i64" => Ok(Self::I64(parse_with_type(value)?)), + "i128" => Ok(Self::I128(parse_with_type(value)?)), + // bytes31 is a helper type defined in Cairo corelib; + // (e.g. alexandria_data_structures::bit_array::BitArray uses that) + // https://github.com/starkware-libs/cairo/blob/bf48e658b9946c2d5446eeb0c4f84868e0b193b5/corelib/src/bytes_31.cairo#L14 + // It's actually felt under the hood. Although conversion from felt252 to bytes31 returns Result, it never fails. + "bytes31" => Ok(Self::Felt(parse_with_type::(value)?.into())), + "felt252" | "felt" | "ContractAddress" | "ClassHash" | "StorageAddress" + | "EthAddress" => { + let felt = Felt::from_dec_str(value) + .with_context(|| neat_parsing_error_message(value, type_with_path, None))?; + Ok(Self::Felt(felt)) + } + "bool" => Ok(Self::Bool(parse_with_type(value)?)), + "ByteArray" => Ok(Self::ByteArray(ByteArray::from(value))), + _ => { + bail!(neat_parsing_error_message( + value, + type_with_path, + Some(&format!("unsupported type {type_with_path}")) + )) + } + } + } +} + +#[derive(Debug)] +pub(super) struct CalldataTuple(Vec); + +impl CalldataTuple { + pub fn new(arguments: Vec) -> Self { + Self(arguments) + } +} + +#[derive(Debug)] +pub(super) enum AllowedCalldataArguments { + Struct(CalldataStruct), + ArrayMacro(CalldataArrayMacro), + Enum(CalldataEnum), + Primitive(CalldataPrimitive), + Tuple(CalldataTuple), +} + +impl CairoSerialize for CalldataPrimitive { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/ + fn serialize(&self, output: &mut BufferWriter) { + match self { + CalldataPrimitive::Bool(value) => value.serialize(output), + CalldataPrimitive::U8(value) => value.serialize(output), + CalldataPrimitive::U16(value) => value.serialize(output), + CalldataPrimitive::U32(value) => value.serialize(output), + CalldataPrimitive::U64(value) => value.serialize(output), + CalldataPrimitive::U128(value) => value.serialize(output), + CalldataPrimitive::U256(value) => value.serialize(output), + CalldataPrimitive::U512(value) => value.serialize(output), + CalldataPrimitive::I8(value) => value.serialize(output), + CalldataPrimitive::I16(value) => value.serialize(output), + CalldataPrimitive::I32(value) => value.serialize(output), + CalldataPrimitive::I64(value) => value.serialize(output), + CalldataPrimitive::I128(value) => value.serialize(output), + CalldataPrimitive::Felt(value) => value.serialize(output), + CalldataPrimitive::ByteArray(value) => value.serialize(output), + }; + } +} + +impl CairoSerialize for CalldataStructField { + // Every argument serialized in order of occurrence + fn serialize(&self, output: &mut BufferWriter) { + self.0.serialize(output); + } +} + +impl CairoSerialize for CalldataStruct { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/#serialization_of_structs + fn serialize(&self, output: &mut BufferWriter) { + self.0.iter().for_each(|field| field.serialize(output)); + } +} + +impl CairoSerialize for CalldataTuple { + fn serialize(&self, output: &mut BufferWriter) { + self.0.iter().for_each(|field| field.serialize(output)); + } +} + +impl CairoSerialize for CalldataArrayMacro { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/#serialization_of_arrays + fn serialize(&self, output: &mut BufferWriter) { + self.0.len().serialize(output); + self.0.iter().for_each(|field| field.serialize(output)); + } +} + +impl CairoSerialize for CalldataEnum { + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/serialization-of-cairo-types/#serialization_of_enums + fn serialize(&self, output: &mut BufferWriter) { + self.position.serialize(output); + if self.argument.is_some() { + self.argument.as_ref().unwrap().serialize(output); + } + } +} +impl CairoSerialize for AllowedCalldataArguments { + fn serialize(&self, output: &mut BufferWriter) { + match self { + AllowedCalldataArguments::Struct(value) => value.serialize(output), + AllowedCalldataArguments::ArrayMacro(value) => value.serialize(output), + AllowedCalldataArguments::Enum(value) => value.serialize(output), + AllowedCalldataArguments::Primitive(value) => value.serialize(output), + AllowedCalldataArguments::Tuple(value) => value.serialize(output), + } + } +} diff --git a/crates/sncast/src/helpers/data_transformer/mod.rs b/crates/sncast/src/helpers/data_transformer/mod.rs new file mode 100644 index 0000000000..00a1c860eb --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/mod.rs @@ -0,0 +1,3 @@ +pub mod calldata_representation; +pub mod sierra_abi; +pub mod transformer; diff --git a/crates/sncast/src/helpers/data_transformer/sierra_abi.rs b/crates/sncast/src/helpers/data_transformer/sierra_abi.rs new file mode 100644 index 0000000000..199068cfa8 --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/sierra_abi.rs @@ -0,0 +1,621 @@ +use crate::helpers::data_transformer::calldata_representation::{ + AllowedCalldataArguments, CalldataArrayMacro, CalldataEnum, CalldataPrimitive, CalldataStruct, + CalldataStructField, CalldataTuple, +}; +use anyhow::{bail, ensure, Context, Result}; +use cairo_lang_parser::utils::SimpleParserDatabase; +use cairo_lang_syntax::node::ast::PathSegment::Simple; +use cairo_lang_syntax::node::ast::{ + ArgClause, ArgList, Expr, ExprFunctionCall, ExprInlineMacro, ExprListParenthesized, ExprPath, + ExprStructCtorCall, ExprUnary, OptionStructArgExpr, PathSegment, StructArg, TerminalFalse, + TerminalLiteralNumber, TerminalShortString, TerminalString, TerminalTrue, UnaryOperator, + WrappedArgList, +}; +use cairo_lang_syntax::node::{Terminal, Token}; +use itertools::Itertools; +use regex::Regex; +use starknet::core::types::contract::{AbiEntry, AbiEnum, AbiNamedMember, AbiStruct}; +use std::collections::HashSet; +use std::ops::Neg; + +pub(super) fn build_representation( + expression: Expr, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, +) -> Result { + match expression { + Expr::StructCtorCall(item) => item.transform(expected_type, abi, db), + Expr::Literal(item) => item.transform(expected_type, abi, db), + Expr::Unary(item) => item.transform(expected_type, abi, db), + Expr::ShortString(item) => item.transform(expected_type, abi, db), + Expr::String(item) => item.transform(expected_type, abi, db), + Expr::False(item) => item.transform(expected_type, abi, db), + Expr::True(item) => item.transform(expected_type, abi, db), + Expr::Path(item) => item.transform(expected_type, abi, db), + Expr::FunctionCall(item) => item.transform(expected_type, abi, db), + Expr::InlineMacro(item) => item.transform(expected_type, abi, db), + Expr::Tuple(item) => item.transform(expected_type, abi, db), + _ => { + bail!(r#"Invalid argument type: unsupported expression for type "{expected_type}""#) + } + } +} + +trait SupportedCalldataKind { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result; +} + +impl SupportedCalldataKind for ExprStructCtorCall { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let struct_path: Vec = split(&self.path(db), db)?; + let struct_path_joined = struct_path.clone().join("::"); + + validate_path_argument(expected_type, &struct_path, &struct_path_joined)?; + + let structs_from_abi = find_all_structs(abi); + let struct_abi_definition = find_valid_enum_or_struct(structs_from_abi, &struct_path)?; + + let struct_args = self.arguments(db).arguments(db).elements(db); + + let struct_args_with_values = get_struct_arguments_with_values(&struct_args, db) + .context("Found invalid expression in struct argument")?; + + if struct_args_with_values.len() != struct_abi_definition.members.len() { + bail!( + r#"Invalid number of struct arguments in struct "{}", expected {} arguments, found {}"#, + struct_path_joined, + struct_abi_definition.members.len(), + struct_args.len() + ) + } + + // validate if all arguments' names have corresponding names in abi + if struct_args_with_values + .iter() + .map(|(arg_name, _)| arg_name.clone()) + .collect::>() + != struct_abi_definition + .members + .iter() + .map(|x| x.name.clone()) + .collect::>() + { + // TODO add message which arguments are invalid (Issue #2549) + bail!( + r#"Arguments in constructor invocation for struct {} do not match struct arguments in ABI"#, + expected_type + ) + } + + let fields = struct_args_with_values + .into_iter() + .map(|(arg_name, expr)| { + let abi_entry = struct_abi_definition + .members + .iter() + .find(|&abi_member| abi_member.name == arg_name) + .expect("Arg name should be in ABI - it is checked before with HashSets"); + Ok(CalldataStructField::new(build_representation( + expr, + &abi_entry.r#type, + abi, + db, + )?)) + }) + .collect::>>()?; + + Ok(AllowedCalldataArguments::Struct(CalldataStruct::new( + fields, + ))) + } +} + +impl SupportedCalldataKind for TerminalLiteralNumber { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let (value, suffix) = self + .numeric_value_and_suffix(db) + .with_context(|| format!("Couldn't parse value: {}", self.text(db)))?; + + let proper_param_type = match suffix { + None => expected_type, + Some(ref suffix) => suffix.as_str(), + }; + + Ok(AllowedCalldataArguments::Primitive( + CalldataPrimitive::try_new(proper_param_type, &value.to_string())?, + )) + } +} + +impl SupportedCalldataKind for ExprUnary { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let (value, suffix) = match self.expr(db) { + Expr::Literal(literal_number) => literal_number + .numeric_value_and_suffix(db) + .with_context(|| format!("Couldn't parse value: {}", literal_number.text(db))), + _ => bail!("Invalid expression with unary operator, only numbers allowed"), + }?; + + let proper_param_type = match suffix { + None => expected_type, + Some(ref suffix) => suffix.as_str(), + }; + + match self.op(db) { + UnaryOperator::Not(_) => bail!( + "Invalid unary operator in expression !{} , only - allowed, got !", + value + ), + UnaryOperator::Desnap(_) => bail!( + "Invalid unary operator in expression *{} , only - allowed, got *", + value + ), + UnaryOperator::BitNot(_) => bail!( + "Invalid unary operator in expression ~{} , only - allowed, got ~", + value + ), + UnaryOperator::At(_) => bail!( + "Invalid unary operator in expression @{} , only - allowed, got @", + value + ), + UnaryOperator::Minus(_) => {} + } + + Ok(AllowedCalldataArguments::Primitive( + CalldataPrimitive::try_new(proper_param_type, &value.neg().to_string())?, + )) + } +} + +impl SupportedCalldataKind for TerminalShortString { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let value = self + .string_value(db) + .context("Invalid shortstring passed as an argument")?; + + Ok(AllowedCalldataArguments::Primitive( + CalldataPrimitive::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for TerminalString { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let value = self + .string_value(db) + .context("Invalid string passed as an argument")?; + + Ok(AllowedCalldataArguments::Primitive( + CalldataPrimitive::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for TerminalFalse { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Could use terminal_false.boolean_value(db) and simplify try_new() + let value = self.text(db).to_string(); + + Ok(AllowedCalldataArguments::Primitive( + CalldataPrimitive::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for TerminalTrue { + fn transform( + &self, + expected_type: &str, + _abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + let value = self.text(db).to_string(); + + Ok(AllowedCalldataArguments::Primitive( + CalldataPrimitive::try_new(expected_type, &value)?, + )) + } +} + +impl SupportedCalldataKind for ExprPath { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Enums with no value - Enum::Variant + let enum_path_with_variant = split(self, db)?; + let (enum_variant_name, enum_path) = enum_path_with_variant.split_last().unwrap(); + let enum_path_joined = enum_path.join("::"); + + validate_path_argument(expected_type, enum_path, &enum_path_joined)?; + + let (enum_position, enum_variant) = + find_enum_variant_position(enum_variant_name, enum_path, abi)?; + + if enum_variant.r#type != "()" { + bail!( + r#"Couldn't find variant "{}" in enum "{}""#, + enum_variant_name, + enum_path_joined + ) + } + + Ok(AllowedCalldataArguments::Enum(CalldataEnum::new( + enum_position, + None, + ))) + } +} + +impl SupportedCalldataKind for ExprFunctionCall { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Enums with value - Enum::Variant(10) + let enum_path_with_variant = split(&self.path(db), db)?; + let (enum_variant_name, enum_path) = enum_path_with_variant.split_last().unwrap(); + let enum_path_joined = enum_path.join("::"); + + validate_path_argument(expected_type, enum_path, &enum_path_joined)?; + + let (enum_position, enum_variant) = + find_enum_variant_position(enum_variant_name, enum_path, abi)?; + + // When creating an enum with variant, there can be only one argument. Parsing the + // argument inside ArgList (enum_expr_path_with_value.arguments(db).arguments(db)), + // then popping from the vector and unwrapping safely. + let expr = parse_argument_list(&self.arguments(db).arguments(db), db)? + .pop() + .unwrap(); + let parsed_expr = build_representation(expr, &enum_variant.r#type, abi, db)?; + + Ok(AllowedCalldataArguments::Enum(CalldataEnum::new( + enum_position, + Some(Box::new(parsed_expr)), + ))) + } +} + +impl SupportedCalldataKind for ExprInlineMacro { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // array![] calls + let parsed_exprs = parse_inline_macro(self, db)?; + + let array_element_type_pattern = Regex::new("core::array::Array::<(.*)>").unwrap(); + let abi_argument_type = array_element_type_pattern + .captures(expected_type) + .with_context(|| { + format!(r#"Invalid argument type, expected "{expected_type}", got array"#,) + })? + .get(1) + .with_context(|| { + format!( + "Couldn't parse array element type from the ABI array parameter: {expected_type}" + ) + })? + .as_str(); + + let arguments = parsed_exprs + .into_iter() + .map(|arg| build_representation(arg, abi_argument_type, abi, db)) + .collect::>>()?; + + Ok(AllowedCalldataArguments::ArrayMacro( + CalldataArrayMacro::new(arguments), + )) + } +} + +impl SupportedCalldataKind for ExprListParenthesized { + fn transform( + &self, + expected_type: &str, + abi: &[AbiEntry], + db: &SimpleParserDatabase, + ) -> Result { + // Regex capturing types between the parentheses, e.g.: for "(core::felt252, core::u8)" + // will capture "core::felt252, core::u8" + let tuple_types_pattern = Regex::new(r"\(([^)]+)\)").unwrap(); + let tuple_types: Vec<&str> = tuple_types_pattern + .captures(expected_type) + .with_context(|| { + format!(r#"Invalid argument type, expected "{expected_type}", got tuple"#,) + })? + .get(1) + .map(|x| x.as_str().split(", ").collect()) + .unwrap(); + + let parsed_exprs = self + .expressions(db) + .elements(db) + .into_iter() + .zip(tuple_types) + .map(|(expr, single_param)| build_representation(expr, single_param, abi, db)) + .collect::>>()?; + + Ok(AllowedCalldataArguments::Tuple(CalldataTuple::new( + parsed_exprs, + ))) + } +} + +fn split(path: &ExprPath, db: &SimpleParserDatabase) -> Result> { + path.elements(db) + .iter() + .map(|p| match p { + Simple(segment) => Ok(segment.ident(db).token(db).text(db).to_string()), + PathSegment::WithGenericArgs(_) => { + bail!("Cannot use generic args when specifying struct/enum path") + } + }) + .collect::>>() +} + +fn get_struct_arguments_with_values( + arguments: &[StructArg], + db: &SimpleParserDatabase, +) -> Result> { + arguments + .iter() + .map(|elem| { + match elem { + // Holds info about parameter and argument in struct creation, e.g.: + // in case of "Struct { a: 1, b: 2 }", two separate StructArgSingle hold info + // about "a: 1" and "b: 2" respectively. + StructArg::StructArgSingle(whole_arg) => { + match whole_arg.arg_expr(db) { + // It's probably a case of constructor invocation `Struct {a, b}` catching variables `a` and `b` from context + OptionStructArgExpr::Empty(_) => { + bail!( + "Shorthand arguments are not allowed - used \"{ident}\", expected \"{ident}: value\"", + ident = whole_arg.identifier(db).text(db) + ) + } + // Holds info about the argument, e.g.: in case of "a: 1" holds info + // about ": 1" + OptionStructArgExpr::StructArgExpr(arg_value_with_colon) => Ok(( + whole_arg.identifier(db).text(db).to_string(), + arg_value_with_colon.expr(db), + )), + } + } + StructArg::StructArgTail(_) => { + bail!("Struct unpack-init with \"..\" operator is not allowed") + } + } + }) + .collect() +} + +fn find_enum_variant_position<'a>( + variant: &String, + path: &[String], + abi: &'a [AbiEntry], +) -> Result<(usize, &'a AbiNamedMember)> { + let enums_from_abi = abi + .iter() + .filter_map(|abi_entry| { + if let AbiEntry::Enum(abi_enum) = abi_entry { + Some(abi_enum) + } else { + None + } + }) + .collect::>(); + + let enum_abi_definition = find_valid_enum_or_struct(enums_from_abi, path)?; + + let position_and_enum_variant = enum_abi_definition + .variants + .iter() + .find_position(|item| item.name == *variant) + .with_context(|| { + format!( + r#"Couldn't find variant "{}" in enum "{}""#, + variant, + path.join("::") + ) + })?; + + Ok(position_and_enum_variant) +} + +fn parse_argument_list(arguments: &ArgList, db: &SimpleParserDatabase) -> Result> { + let arguments = arguments.elements(db); + if arguments + .iter() + .map(|arg| arg.modifiers(db).elements(db)) + .any(|mod_list| !mod_list.is_empty()) + { + bail!("\"ref\" and \"mut\" modifiers are not allowed") + } + + arguments + .iter() + .map(|arg| match arg.arg_clause(db) { + ArgClause::Unnamed(expr) => Ok(expr.value(db)), + ArgClause::Named(_) => { + bail!("Named arguments are not allowed") + } + ArgClause::FieldInitShorthand(_) => { + bail!("Field init shorthands are not allowed") + } + }) + .collect::>>() +} + +fn parse_inline_macro( + invocation: &ExprInlineMacro, + db: &SimpleParserDatabase, +) -> Result> { + match invocation + .path(db) + .elements(db) + .iter() + .last() + .expect("Macro must have a name") + { + Simple(simple) => { + let macro_name = simple.ident(db).text(db); + if macro_name != "array" { + bail!( + r#"Invalid macro name, expected "array![]", got "{}""#, + macro_name + ) + } + } + PathSegment::WithGenericArgs(_) => { + bail!("Invalid path specified: generic args in array![] macro not supported") + } + }; + + match invocation.arguments(db) { + WrappedArgList::BracketedArgList(args) => { + let arglist = args.arguments(db); + parse_argument_list(&arglist, db) + } + WrappedArgList::ParenthesizedArgList(_) | WrappedArgList::BracedArgList(_) => + bail!("`array` macro supports only square brackets: array![]"), + WrappedArgList::Missing(_) => unreachable!("If any type of parentheses is missing, then diagnostics have been reported and whole flow should have already been terminated.") + } +} + +fn find_all_structs(abi: &[AbiEntry]) -> Vec<&AbiStruct> { + abi.iter() + .filter_map(|entry| match entry { + AbiEntry::Struct(r#struct) => Some(r#struct), + _ => None, + }) + .collect() +} + +fn validate_path_argument( + param_type: &str, + path_argument: &[String], + path_argument_joined: &String, +) -> Result<()> { + if *path_argument.last().unwrap() != param_type.split("::").last().unwrap() + && path_argument_joined != param_type + { + bail!( + r#"Invalid argument type, expected "{}", got "{}""#, + param_type, + path_argument_joined + ) + } + Ok(()) +} + +trait EnumOrStruct { + const VARIANT: &'static str; + const VARIANT_CAPITALIZED: &'static str; + fn name(&self) -> String; +} + +impl EnumOrStruct for AbiStruct { + const VARIANT: &'static str = "struct"; + const VARIANT_CAPITALIZED: &'static str = "Struct"; + + fn name(&self) -> String { + self.name.clone() + } +} + +impl EnumOrStruct for AbiEnum { + const VARIANT: &'static str = "enum"; + const VARIANT_CAPITALIZED: &'static str = "Enum"; + + fn name(&self) -> String { + self.name.clone() + } +} + +// 'item' here means enum or struct +fn find_valid_enum_or_struct<'item, T: EnumOrStruct>( + items_from_abi: Vec<&'item T>, + path: &[String], +) -> Result<&'item T> { + // Argument is a module path to an item (module_name::StructName {}) + if path.len() > 1 { + let full_path_item = items_from_abi + .into_iter() + .find(|x| x.name() == path.join("::")); + + ensure!( + full_path_item.is_some(), + r#"{} "{}" not found in ABI"#, + T::VARIANT_CAPITALIZED, + path.join("::") + ); + + return Ok(full_path_item.unwrap()); + } + + // Argument is just the name of the item (Struct {}) + let mut matching_items_from_abi: Vec<&T> = items_from_abi + .into_iter() + .filter(|x| x.name().split("::").last() == path.last().map(String::as_str)) + .collect(); + + ensure!( + !matching_items_from_abi.is_empty(), + r#"{} "{}" not found in ABI"#, + T::VARIANT_CAPITALIZED, + path.join("::") + ); + + ensure!( + matching_items_from_abi.len() == 1, + r#"Found more than one {} "{}" in ABI, please specify a full path to the item"#, + T::VARIANT, + path.join("::") + ); + + Ok(matching_items_from_abi.pop().unwrap()) +} diff --git a/crates/sncast/src/helpers/data_transformer/transformer.rs b/crates/sncast/src/helpers/data_transformer/transformer.rs new file mode 100644 index 0000000000..1aa657e4ee --- /dev/null +++ b/crates/sncast/src/helpers/data_transformer/transformer.rs @@ -0,0 +1,161 @@ +use crate::helpers::data_transformer::sierra_abi::build_representation; +use anyhow::{bail, ensure, Context, Result}; +use cairo_lang_diagnostics::DiagnosticsBuilder; +use cairo_lang_filesystem::ids::{FileKind, FileLongId, VirtualFile}; +use cairo_lang_parser::parser::Parser; +use cairo_lang_parser::utils::SimpleParserDatabase; +use cairo_lang_syntax::node::ast::Expr; +use cairo_lang_utils::Intern; +use conversions::serde::serialize::SerializeToFeltVec; +use itertools::Itertools; +use starknet::core::types::contract::{AbiEntry, AbiFunction, StateMutability}; +use starknet::core::types::{ContractClass, Felt}; +use starknet::core::utils::get_selector_from_name; +use std::collections::HashMap; + +pub fn transform( + calldata: &[String], + class_definition: ContractClass, + function_selector: &Felt, +) -> Result> { + let sierra_class = match class_definition { + ContractClass::Sierra(class) => class, + ContractClass::Legacy(_) => { + bail!("Transformation of Cairo-like expressions is not available for Cairo0 contracts") + } + }; + + let abi: Vec = serde_json::from_str(sierra_class.abi.as_str()) + .context("Couldn't deserialize ABI received from chain")?; + + let selector_function_map = map_selectors_to_functions(&abi); + + let function = selector_function_map + .get(function_selector) + .with_context(|| { + format!( + r#"Function with selector "{function_selector}" not found in ABI of the contract"# + ) + })?; + + let db = SimpleParserDatabase::default(); + + let result_for_cairo_expression_input = + process_as_cairo_expressions(calldata, function, &abi, &db) + .context("Error while processing Cairo-like calldata"); + + if result_for_cairo_expression_input.is_ok() { + return result_for_cairo_expression_input; + } + + let result_for_serialized_input = process_as_serialized(calldata, &abi, &db) + .context("Error while processing serialized calldata"); + + match result_for_serialized_input { + Err(_) => result_for_cairo_expression_input, + ok => ok, + } +} + +fn process_as_cairo_expressions( + calldata: &[String], + function: &AbiFunction, + abi: &[AbiEntry], + db: &SimpleParserDatabase, +) -> Result> { + let n_inputs = function.inputs.len(); + let n_arguments = calldata.len(); + + ensure!( + n_inputs == n_arguments, + "Invalid number of arguments: passed {}, expected {}", + n_arguments, + n_inputs, + ); + + function + .inputs + .iter() + .zip(calldata) + .map(|(parameter, value)| { + let expr = parse(value, db)?; + let representation = build_representation(expr, ¶meter.r#type, abi, db)?; + Ok(representation.serialize_to_vec()) + }) + .flatten_ok() + .collect::>() +} + +fn process_as_serialized( + calldata: &[String], + abi: &[AbiEntry], + db: &SimpleParserDatabase, +) -> Result> { + calldata + .iter() + .map(|expression| { + let expr = parse(expression, db)?; + let representation = build_representation(expr, "felt252", abi, db)?; + Ok(representation.serialize_to_vec()) + }) + .flatten_ok() + .collect::>() +} + +fn map_selectors_to_functions(abi: &[AbiEntry]) -> HashMap { + let mut map = HashMap::new(); + + for abi_entry in abi { + match abi_entry { + AbiEntry::Function(func) => { + map.insert( + get_selector_from_name(func.name.as_str()).unwrap(), + func.clone(), + ); + } + AbiEntry::Constructor(constructor) => { + // Transparency of constructors and other functions + map.insert( + get_selector_from_name(constructor.name.as_str()).unwrap(), + AbiFunction { + name: constructor.name.clone(), + inputs: constructor.inputs.clone(), + outputs: vec![], + state_mutability: StateMutability::View, + }, + ); + } + AbiEntry::Interface(interface) => { + map.extend(map_selectors_to_functions(&interface.items)); + } + _ => {} + } + } + + map +} + +fn parse(source: &str, db: &SimpleParserDatabase) -> Result { + let file = FileLongId::Virtual(VirtualFile { + parent: None, + name: "parser_input".into(), + content: source.to_string().into(), + code_mappings: [].into(), + kind: FileKind::Expr, + }) + .intern(db); + + let mut diagnostics = DiagnosticsBuilder::default(); + let expression = Parser::parse_file_expr(db, &mut diagnostics, file, source); + let diagnostics = diagnostics.build(); + + if diagnostics.check_error_free().is_err() { + bail!( + "Invalid Cairo expression found in input calldata \"{}\":\n{}", + source, + diagnostics.format(db) + ) + } + + Ok(expression) +} diff --git a/crates/sncast/src/helpers/mod.rs b/crates/sncast/src/helpers/mod.rs index b9ac734f9e..9453a22e45 100644 --- a/crates/sncast/src/helpers/mod.rs +++ b/crates/sncast/src/helpers/mod.rs @@ -2,6 +2,7 @@ pub mod block_explorer; pub mod braavos; pub mod configuration; pub mod constants; +pub mod data_transformer; pub mod error; pub mod fee; pub mod rpc; diff --git a/crates/sncast/tests/integration/data_transformer.rs b/crates/sncast/tests/integration/data_transformer.rs new file mode 100644 index 0000000000..8b6f0101e1 --- /dev/null +++ b/crates/sncast/tests/integration/data_transformer.rs @@ -0,0 +1,1022 @@ +use core::fmt; +use itertools::Itertools; +use primitive_types::U256; +use shared::rpc::create_rpc_client; +use sncast::helpers::data_transformer::transformer::transform; +use starknet::core::types::{BlockId, BlockTag, ContractClass, Felt}; +use starknet::core::utils::get_selector_from_name; +use starknet::providers::Provider; +use std::ops::Not; +use test_case::test_case; +use tokio::sync::OnceCell; + +const RPC_ENDPOINT: &str = "http://188.34.188.184:7070/rpc/v0_7"; + +// https://sepolia.starkscan.co/class/0x02a9b456118a86070a8c116c41b02e490f3dcc9db3cad945b4e9a7fd7cec9168#code +const TEST_CLASS_HASH: Felt = + Felt::from_hex_unchecked("0x02a9b456118a86070a8c116c41b02e490f3dcc9db3cad945b4e9a7fd7cec9168"); + +static CLASS: OnceCell = OnceCell::const_new(); + +// 2^128 + 3 +// const BIG_NUMBER: &str = "340282366920938463463374607431768211459"; + +async fn init_class() -> ContractClass { + let client = create_rpc_client(RPC_ENDPOINT).unwrap(); + + client + .get_class(BlockId::Tag(BlockTag::Latest), TEST_CLASS_HASH) + .await + .unwrap() +} + +trait Contains { + fn assert_contains(&self, value: T); +} + +impl Contains<&str> for anyhow::Error { + fn assert_contains(&self, value: &str) { + self.chain() + .any(|err| err.to_string().contains(value)) + .not() + .then(|| panic!("{value:?}\nnot found in\n{self:#?}")); + } +} + +#[tokio::test] +async fn test_function_not_found() { + let simulated_cli_input = vec![String::from("'some_felt'")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("nonexistent_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + assert!(output.is_err()); + output.unwrap_err().assert_contains( + format!(r#"Function with selector "{selector}" not found in ABI of the contract"#).as_str(), + ); +} + +#[tokio::test] +async fn test_happy_case_numeric_type_suffix() -> anyhow::Result<()> { + let simulated_cli_input = vec![String::from("1010101_u32")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("unsigned_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector)?; + + assert_eq!(output, vec![Felt::from(1_010_101_u32)]); + + Ok(()) +} + +#[tokio::test] +async fn test_invalid_numeric_type_suffix() { + let simulated_cli_input = vec![String::from("1_u10")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("simple_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + assert!(output.is_err()); + + output + .unwrap_err() + .assert_contains(r#"Failed to parse value "1" into type "u10": unsupported type u10"#); +} + +#[tokio::test] +async fn test_invalid_cairo_expression() { + let simulated_cli_input = vec![String::from("some_invalid_expression:")]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("simple_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + output + .unwrap_err() + .assert_contains("Invalid Cairo expression found in input calldata"); +} + +#[tokio::test] +async fn test_invalid_argument_number() { + let simulated_cli_input = vec!["0x123", "'some_obsolete_argument'", "10"] + .into_iter() + .map(String::from) + .collect_vec(); + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("simple_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + output + .unwrap_err() + .assert_contains("Invalid number of arguments: passed 3, expected 1"); +} + +#[tokio::test] +async fn test_happy_case_simple_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("100")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![100.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_simple_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x64")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![100.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_u256_function_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![U256::MAX.to_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("u256_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0xffffffffffffffffffffffffffffffff"), + Felt::from_hex_unchecked("0xffffffffffffffffffffffffffffffff"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_u256_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x2137"), String::from("0x0")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("u256_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x2137"), + Felt::from_hex_unchecked("0x0"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_signed_function_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("-273")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("signed_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(-273i16)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_signed_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![Felt::from(-273i16).to_hex_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("signed_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(-273i16)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +// Problem: Although transformer fails to process the given input as `i32`, itthen succeeds to interpret it as `felt252` +// Overflow checks will not work for functions having the same serialized and Cairo-like calldata length. +// User must provide a type suffix or get the invoke-time error +#[ignore = "Impossible to pass with the current solution"] +#[tokio::test] +async fn test_signed_fn_overflow() { + let simulated_cli_input = vec![(i32::MAX as u64 + 1).to_string()]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("signed_fn").unwrap(); + + let output = transform(&simulated_cli_input, contract_class, &selector); + + output + .unwrap_err() + .assert_contains(r#"Failed to parse value "2147483648" into type "i32""#); +} + +#[tokio::test] +async fn test_signed_fn_overflow_with_type_suffix() { + let simulated_cli_input = vec![format!("{}_i32", i32::MAX as u64 + 1)]; + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + let selector = get_selector_from_name("signed_fn").unwrap(); + + let result = transform(&simulated_cli_input, contract_class, &selector); + + result + .unwrap_err() + .assert_contains(r#"Failed to parse value "2147483648" into type "i32""#); +} + +#[tokio::test] +async fn test_happy_case_unsigned_function_cairo_expressions_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![u32::MAX.to_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("unsigned_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(u32::MAX)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_unsigned_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![Felt::from(u32::MAX).to_hex_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("unsigned_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from(u32::MAX)]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_tuple_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("(2137_felt252, 1_u8, Enum::One)")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("tuple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![2137.into(), 1.into(), 0.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_tuple_function_with_nested_struct_cairo_expression_input( +) -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from( + "(123, 234, Enum::Three(NestedStructWithField {a: SimpleStruct {a: 345}, b: 456 }))", + )]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("tuple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![123, 234, 2, 345, 456] + .into_iter() + .map(Felt::from) + .collect(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_tuple_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("0x859"), + String::from("0x1"), + String::from("0x0"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("tuple_fn").unwrap(), + )?; + + let expected_output: Vec = vec![2137.into(), 1.into(), 0.into()]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_complex_function_cairo_expressions_input() -> anyhow::Result<()> { + let max_u256 = U256::max_value().to_string(); + + let simulated_cli_input = vec![ + "array![array![0x2137, 0x420], array![0x420, 0x2137]]", + "8_u8", + "-270", + "\"some_string\"", + "(0x69, 100)", + "true", + &max_u256, + ] + .into_iter() + .map(String::from) + .collect_vec(); + + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let result = transform( + &simulated_cli_input, + contract_class, + &get_selector_from_name("complex_fn").unwrap(), + )?; + + // Manually serialized in Cairo + let expected_output: Vec = vec![ + "2", + "2", + "8503", + "1056", + "2", + "1056", + "8503", + "8", + "3618502788666131213697322783095070105623107215331596699973092056135872020211", + "0", + "139552669935068984642203239", + "11", + "105", + "100", + "1", + "340282366920938463463374607431768211455", + "340282366920938463463374607431768211455", + ] + .into_iter() + .map(Felt::from_dec_str) + .collect::>() + .unwrap(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_complex_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + // Input identical to `[..]complex_function_cairo_expressions_input` + let input: Vec = [ + "0x2", + "0x2", + "0x2137", + "0x420", + "0x2", + "0x420", + "0x2137", + "0x8", + "0x800000000000010fffffffffffffffffffffffffffffffffffffffffffffef3", + "0x0", + "0x736f6d655f737472696e67", + "0xb", + "0x69", + "0x64", + "0x1", + "0xffffffffffffffffffffffffffffffff", + "0xffffffffffffffffffffffffffffffff", + ] + .into_iter() + .map(String::from) + .collect(); + + let result = transform( + &input, + contract_class, + &get_selector_from_name("complex_fn").unwrap(), + )?; + + let expected_output: Vec = vec![ + "2", + "2", + "8503", + "1056", + "2", + "1056", + "8503", + "8", + "3618502788666131213697322783095070105623107215331596699973092056135872020211", + "0", + "139552669935068984642203239", + "11", + "105", + "100", + "1", + "340282366920938463463374607431768211455", + "340282366920938463463374607431768211455", + ] + .into_iter() + .map(Felt::from_dec_str) + .collect::>() + .unwrap(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_simple_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("SimpleStruct {a: 0x12}")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from_hex_unchecked("0x12")]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_simple_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x12")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + )?; + + let expected_output = vec![Felt::from_hex_unchecked("0x12")]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_simple_struct_function_invalid_struct_argument() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from(r#"SimpleStruct {a: "string"}"#)]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Failed to parse value "string" into type "core::felt252""#); +} + +#[tokio::test] +async fn test_simple_struct_function_invalid_struct_name() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("InvalidStructName {a: 0x10}")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got "InvalidStructName""#); +} + +#[test_case(r#""string_argument""#, r#"Failed to parse value "string_argument" into type "data_transformer_contract::SimpleStruct""# ; "string")] +#[test_case("'shortstring'", r#"Failed to parse value "shortstring" into type "data_transformer_contract::SimpleStruct""# ; "shortstring")] +#[test_case("true", r#"Failed to parse value "true" into type "data_transformer_contract::SimpleStruct""# ; "bool")] +#[test_case("array![0x1, 2, 0x3, 04]", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got array"# ; "array")] +#[test_case("(1, array![2], 0x3)", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got tuple"# ; "tuple")] +#[test_case("My::Enum", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got "My""# ; "enum_variant")] +#[test_case("core::path::My::Enum(10)", r#"Invalid argument type, expected "data_transformer_contract::SimpleStruct", got "core::path::My""# ; "enum_variant_with_path")] +#[tokio::test] +async fn test_simple_struct_function_cairo_expression_input_invalid_argument_type( + input: &str, + error_message: &str, +) { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![input.to_string()]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + ); + + result.unwrap_err().assert_contains(error_message); +} + +#[tokio::test] +async fn test_happy_case_nested_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from( + "NestedStructWithField { a: SimpleStruct { a: 0x24 }, b: 96 }", + )]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("nested_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x24"), + Felt::from_hex_unchecked("0x60"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_nested_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x24"), String::from("0x60")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("simple_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x24"), + Felt::from_hex_unchecked("0x60"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_empty_variant_cairo_expression_input() -> anyhow::Result<()> +{ + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("Enum::One")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![Felt::ZERO]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_empty_variant_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x0")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![Felt::ZERO]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_one_argument_variant_cairo_expression_input( +) -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("Enum::Two(128)")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x80"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_one_argument_variant_serialized_input() -> anyhow::Result<()> +{ + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x1"), String::from("0x80")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x80"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_nested_struct_variant_cairo_expression_input( +) -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from( + "Enum::Three(NestedStructWithField { a: SimpleStruct { a: 123 }, b: 234 })", + )]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x7b"), + Felt::from_hex_unchecked("0xea"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_enum_function_nested_struct_variant_serialized_input() -> anyhow::Result<()> +{ + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("0x2"), + String::from("0x7b"), + String::from("0xea"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x7b"), + Felt::from_hex_unchecked("0xea"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_enum_function_invalid_variant_cairo_expression_input() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("Enum::InvalidVariant")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Couldn't find variant "InvalidVariant" in enum "Enum""#); +} + +#[tokio::test] +async fn test_happy_case_complex_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let data = concat!( + r#"ComplexStruct {"#, + r#" a: NestedStructWithField {"#, + r#" a: SimpleStruct { a: 1 },"#, + r#" b: 2"#, + r#" },"#, + r#" b: 3, c: 4, d: 5,"#, + r#" e: Enum::Two(6),"#, + r#" f: "seven","#, + r#" g: array![8, 9],"#, + r#" h: 10, i: (11, 12)"#, + r#"}"#, + ); + + let input = vec![String::from(data)]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("complex_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + // a: NestedStruct + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x2"), + // b: felt252 + Felt::from_hex_unchecked("0x3"), + // c: u8 + Felt::from_hex_unchecked("0x4"), + // d: i32 + Felt::from_hex_unchecked("0x5"), + // e: Enum + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x6"), + // f: ByteArray + Felt::from_hex_unchecked("0x0"), + Felt::from_hex_unchecked("0x736576656e"), + Felt::from_hex_unchecked("0x5"), + // g: Array + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x8"), + Felt::from_hex_unchecked("0x9"), + // h: u256 + Felt::from_hex_unchecked("0xa"), + Felt::from_hex_unchecked("0x0"), + // i: (i128, u128) + Felt::from_hex_unchecked("0xb"), + Felt::from_hex_unchecked("0xc"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_complex_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let felts = vec![ + // a: NestedStruct + "0x1", + "0x2", + // b: felt252 + "0x3", + // c: u8 + "0x4", + // d: i32 + "0x5", + // e: Enum + "0x1", + "0x6", + // f: ByteArray + "0x0", + "0x736576656e", + "0x5", + // g: Array + "0x2", + "0x8", + "0x9", + // h: u256 + "0xa", + "0x0", + // i: (i128, u128) + "0xb", + "0xc", + ]; + + let input = felts.clone().into_iter().map(String::from).collect_vec(); + + let result = transform( + &input, + contract_class, + &get_selector_from_name("complex_struct_fn").unwrap(), + )?; + + let expected_output = felts + .into_iter() + .map(Felt::from_hex_unchecked) + .collect_vec(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +// TODO add similar test but with enums +// - take existing contract code +// - find/create a library with an enum +// - add to project as a dependency +// - create enum with the same name in your contract code +#[tokio::test] +async fn test_external_struct_function_ambiguous_struct_name_cairo_expression_input() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("BitArray { bit: 23 }"), + String::from("BitArray { data: array![0], current: 1, read_pos: 2, write_pos: 3 }"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("external_struct_fn").unwrap(), + ); + + result.unwrap_err().assert_contains( + r#"Found more than one struct "BitArray" in ABI, please specify a full path to the item"#, + ); +} + +#[tokio::test] +async fn test_happy_case_external_struct_function_cairo_expression_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("data_transformer_contract::BitArray { bit: 23 }"), + String::from("alexandria_data_structures::bit_array::BitArray { data: array![0], current: 1, read_pos: 2, write_pos: 3 }") + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("external_struct_fn").unwrap(), + )?; + + let expected_output = vec![ + Felt::from_hex_unchecked("0x17"), + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x0"), + Felt::from_hex_unchecked("0x1"), + Felt::from_hex_unchecked("0x2"), + Felt::from_hex_unchecked("0x3"), + ]; + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_happy_case_external_struct_function_serialized_input() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let felts = vec!["0x17", "0x1", "0x0", "0x1", "0x2", "0x3"]; + + let input = felts.clone().into_iter().map(String::from).collect_vec(); + + let result = transform( + &input, + contract_class, + &get_selector_from_name("enum_fn").unwrap(), + )?; + + let expected_output = felts + .into_iter() + .map(Felt::from_hex_unchecked) + .collect_vec(); + + assert_eq!(result, expected_output); + + Ok(()) +} + +#[tokio::test] +async fn test_external_struct_function_invalid_path_to_external_struct() { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![ + String::from("something::BitArray { bit: 23 }"), + String::from("BitArray { data: array![0], current: 1, read_pos: 2, write_pos: 3 }"), + ]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("external_struct_fn").unwrap(), + ); + + result + .unwrap_err() + .assert_contains(r#"Struct "something::BitArray" not found in ABI"#); +} + +#[tokio::test] +async fn test_happy_case_contract_constructor() -> anyhow::Result<()> { + let contract_class = CLASS.get_or_init(init_class).await.to_owned(); + + let input = vec![String::from("0x123")]; + + let result = transform( + &input, + contract_class, + &get_selector_from_name("constructor").unwrap(), + )?; + + let expected_output = vec![Felt::from_hex_unchecked("0x123")]; + + assert_eq!(result, expected_output); + + Ok(()) +} diff --git a/crates/sncast/tests/integration/mod.rs b/crates/sncast/tests/integration/mod.rs index ccbb1d5946..78cefcee1f 100644 --- a/crates/sncast/tests/integration/mod.rs +++ b/crates/sncast/tests/integration/mod.rs @@ -1,3 +1,4 @@ +mod data_transformer; mod fee; mod lib_tests; mod wait_for_tx;