Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/macros/src/attribute/with_components/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ pub mod errors {
pub fn MULTIPLE_ACCESS_CONTROL_COMPONENTS(components: &str) -> String {
format!("Only one AccessControl component can be used. Found: [{components}].\n")
}

/// Error when `ERC721Enumerable` is used together with `ERC721Consecutive`.
pub const ERC721_BALANCE_OF_INCOPATIBILITY: &str =
"ERC721Enumerable and ERC721Consecutive interfere with each other in token ownership tracking and should not be used together.\n";
}

#[allow(non_snake_case)]
Expand Down
39 changes: 24 additions & 15 deletions packages/macros/src/attribute/with_components/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ fn validate_contract_module(
return (vec![error], vec![]);
};

// Keep a stringified version of the module body around for validations below.
let body_ast = body.as_syntax_node();
let typed = ast::ModuleBody::from_syntax_node(db, body_ast);
let body_rnode = RewriteNode::from_ast(&typed);
let mut builder = PatchBuilder::new_ex(db, &body_ast);
builder.add_modified(body_rnode);
let (body_code, _) = builder.build();

// 2. Check that the module has the `#[starknet::contract]` attribute (error)
if !item.has_attr(db, CONTRACT_ATTRIBUTE) {
let error = Diagnostic::error(errors::NO_CONTRACT_ATTRIBUTE(CONTRACT_ATTRIBUTE));
Expand All @@ -141,7 +149,17 @@ fn validate_contract_module(
return (vec![error], vec![]);
}

// 4. Check that the module has the corresponding initializers (warning)
// 4. Disallow ERC721Enumerable and ERC721Consecutive being used together (error)
let uses_erc721_enumerable = components_info
.iter()
.any(|c| matches!(c.kind(), AllowedComponents::ERC721Enumerable));
let uses_erc721_consecutive = body_code.contains("ERC721Consecutive");
if uses_erc721_enumerable && uses_erc721_consecutive {
let error = Diagnostic::error(errors::ERC721_BALANCE_OF_INCOPATIBILITY);
return (vec![error], vec![]);
}

// 5. Check that the module has the corresponding initializers (warning)
let components_with_initializer = components_info
.iter()
.filter(|c| c.has_initializer)
Expand Down Expand Up @@ -180,17 +198,8 @@ fn validate_contract_module(
}
}

// 5. Check that the contract has the corresponding immutable configs
// 6. Check that the contract has the corresponding immutable configs
for component in components_info.iter().filter(|c| c.has_immutable_config) {
// Get the body code (maybe we can do this without the builder)
let body_ast = body.as_syntax_node();
let typed = ast::ModuleBody::from_syntax_node(db, body_ast);
let body_rnode = RewriteNode::from_ast(&typed);

let mut builder = PatchBuilder::new_ex(db, &body_ast);
builder.add_modified(body_rnode);
let (code, _) = builder.build();

// Case 1: DefaultConfig is imported and used
let component_parent_path = component
.path
Expand All @@ -200,14 +209,14 @@ fn validate_contract_module(
r"use {component_parent_path}[{{\w:, \n]*DefaultConfig(\s+as\s+\w+)?[{{\w}}, \n]*;"
))
.unwrap();
let default_config_used = default_config_import_re.is_match(&code);
let default_config_used = default_config_import_re.is_match(&body_code);
if default_config_used {
continue;
}

// Case 2: ImmutableConfig is implemented with fully qualified path
let immutable_config_implemented =
code.contains(&format!("of {}::ImmutableConfig", component.name));
body_code.contains(&format!("of {}::ImmutableConfig", component.name));
if immutable_config_implemented {
continue;
}
Expand All @@ -217,11 +226,11 @@ fn validate_contract_module(
r"use {component_parent_path}[\w:]*\w+::[{{\w, \n]*ImmutableConfig(?:\s+as\s+(\w+))?[{{\w}}, \n]*;"
))
.unwrap();
if let Some(captures) = immutable_config_import_re.captures(&code) {
if let Some(captures) = immutable_config_import_re.captures(&body_code) {
// Use the alias if present, otherwise use "ImmutableConfig"
let config_name = captures.get(1).map_or("ImmutableConfig", |m| m.as_str());
let imported_immutable_config_implemented =
code.contains(&format!("of {config_name}"));
body_code.contains(&format!("of {config_name}"));
if imported_immutable_config_implemented {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
source: src/tests/test_with_components.rs
expression: result
snapshot_kind: text
---
TokenStream:

None

Diagnostics:

====
Error: ERC721Enumerable and ERC721Consecutive interfere with each other in token ownership tracking and should not be used together.
====

AuxData:

None
16 changes: 16 additions & 0 deletions packages/macros/src/tests/test_with_components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,22 @@ fn test_with_erc721_enumerable_no_initializer() {
assert_snapshot!(result);
}

#[test]
fn test_cannot_use_erc721_enumerable_with_erc721_consecutive() {
let attribute = quote! { (ERC721Enumerable) };
let item = quote! {
#[starknet::contract]
pub mod MyContract {
use openzeppelin_token::erc721::extensions::ERC721ConsecutiveComponent;

#[storage]
pub struct Storage {}
}
};
let result = get_string_result(attribute, item);
assert_snapshot!(result);
}

#[test]
fn test_with_erc721_receiver() {
let attribute = quote! { (ERC721Receiver) };
Expand Down
4 changes: 4 additions & 0 deletions packages/token/src/erc721/extensions/erc721_consecutive.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
/// Implementation of the ERC-2309 "Consecutive Transfer Extension".
/// This allows batch minting of consecutive token IDs during construction.
///
/// CAUTION: This extension does not call the `ERC721Component::update` function for tokens minted
/// in batch. Any logic added to this function through hooks will not be triggered when tokens
/// are minted in batch.
///
/// IMPORTANT: To properly track sequential burns and enforce consecutive minting rules, this
/// extension requires that `ERC721ConsecutiveComponent::before_update` and
/// `ERC721ConsecutiveComponent::after_update` are called after every transfer, mint, or burn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
///
/// NOTE: Implementing ERC721Component is a requirement for this component to be implemented.
///
/// CAUTION: `ERC721` extensions that implement custom `balanceOf` logic, such as
/// `ERC721Consecutive`, interfere with enumerability and should not be used together with
/// `ERC721Enumerable`.
///
/// IMPORTANT: To properly track token ids, this extension requires that
/// the ERC721EnumerableComponent::before_update function is called after
/// every transfer, mint, or burn operation.
Expand Down
10 changes: 9 additions & 1 deletion packages/utils/src/tests/test_fuzz_checkpoint.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,17 @@ fn test_lower_lookup(checkpoints: Span<Checkpoint>) {
for i in 1..checkpoints.len() {
let index: usize = i.try_into().unwrap();
let checkpoint = *checkpoints.at(index);
let prev_checkpoint = *checkpoints.at(index - 1);
let search_key = checkpoint.key - 1;
let found_value = mock_trace.lower_lookup(search_key);
assert_eq!(found_value, checkpoint.value);
// If search_key equals the previous checkpoint's key, lower_lookup returns that value.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test was incorrectly assuming that checkpoint.key - 1 would always fall between checkpoints, but when the fuzzer generates consecutive keys (key_step = 1), search_key actually matches the previous checkpoint's key exactly, so lower_lookup correctly returns the previous checkpoint's value.

// Otherwise, it returns the current checkpoint's value (first with key >= search_key).
let expected = if search_key == prev_checkpoint.key {
prev_checkpoint.value
} else {
checkpoint.value
};
assert_eq!(found_value, expected);
}
}

Expand Down