From 330057ef1067d96e0bf3fee78016b6824db9be6d Mon Sep 17 00:00:00 2001 From: Eric Lau Date: Fri, 20 Dec 2024 17:29:06 -0500 Subject: [PATCH] Check for missing calls to transitive parent initializers - WIP --- .../contracts/test/ValidationsInitializer.sol | 55 ++++++++ .../core/src/validate-initializers.test.ts | 133 +++++++++++------- packages/core/src/validate/report.ts | 4 +- packages/core/src/validate/run.ts | 128 ++++++++++++++--- 4 files changed, 250 insertions(+), 70 deletions(-) diff --git a/packages/core/contracts/test/ValidationsInitializer.sol b/packages/core/contracts/test/ValidationsInitializer.sol index 63741d604..4377e2019 100644 --- a/packages/core/contracts/test/ValidationsInitializer.sol +++ b/packages/core/contracts/test/ValidationsInitializer.sol @@ -323,6 +323,61 @@ contract Child_Has_PrivateInitializer_Bad is Parent__OnlyInitializingModifier { // ==== Transitive initialization ==== +abstract contract TransitiveGrandparent1 is Initializable { + uint x; + function __TransitiveGrandparent1_init() onlyInitializing internal { + x = 1; + } +} + +abstract contract TransitiveGrandparent2 is Initializable { + uint y; + function __TransitiveGrandparent2_init() onlyInitializing internal { + y = 1; + } +} + +contract TransitiveParent_Ok is TransitiveGrandparent1, TransitiveGrandparent2 { + function initializeParent() initializer public { + __TransitiveGrandparent1_init(); + __TransitiveGrandparent2_init(); + } +} + +contract TransitiveParent_Bad is TransitiveGrandparent1, TransitiveGrandparent2 { + function initializeParent() initializer public { + __TransitiveGrandparent1_init(); + // Does not call __TransitiveGrandparent2_init, and this contract is not abstract, so it is required + } +} + +contract TransitiveChild_Bad_Parent is TransitiveParent_Bad { // this contract is ok but the parent is not + function initialize() initializer public { + initializeParent(); + } +} + +contract TransitiveChild_Bad_Order is TransitiveParent_Bad { // grandparent should be initialized first + function initialize() initializer public { + initializeParent(); + __TransitiveGrandparent2_init(); + } +} + +contract TransitiveChild_Bad_Order2 is TransitiveParent_Bad { // this contract is ok but the parent is not + function initialize() initializer public { + __TransitiveGrandparent2_init(); + initializeParent(); + } +} + +contract TransitiveDuplicate_Bad is TransitiveGrandparent1, TransitiveParent_Ok { + function initialize() initializer public { + __TransitiveGrandparent1_init(); + initializeParent(); + } +} + contract Ownable_Ok is Initializable, ERC20Upgradeable, OwnableUpgradeable { /// @custom:oz-upgrades-unsafe-allow constructor constructor() { diff --git a/packages/core/src/validate-initializers.test.ts b/packages/core/src/validate-initializers.test.ts index 53cd00c08..f0b32fdf5 100644 --- a/packages/core/src/validate-initializers.test.ts +++ b/packages/core/src/validate-initializers.test.ts @@ -37,17 +37,27 @@ function testAccepts(name: string, kind: ValidationOptions['kind']) { testOverride(name, kind, {}, undefined); } -function testRejects(name: string, kind: ValidationOptions['kind'], expectedErrorContains: string) { - testOverride(name, kind, {}, expectedErrorContains); +function testRejects( + name: string, + kind: ValidationOptions['kind'], + expectedError?: { + contains: string; + count: number; + }, +) { + testOverride(name, kind, {}, expectedError); } function testOverride( name: string, kind: ValidationOptions['kind'], opts: ValidationOptions, - expectErrorContains?: string, + expectedError?: { + contains: string; + count: number; + }, ) { - const expectValid = expectErrorContains === undefined; + const expectValid = expectedError === undefined; const optKeys = Object.keys(opts); const describeOpts = optKeys.length > 0 ? '(' + optKeys.join(', ') + ')' : ''; @@ -59,7 +69,11 @@ function testOverride( t.notThrows(assertUpgSafe); } else { const error = t.throws(assertUpgSafe) as ValidationErrors; - t.true(error.message.includes(expectErrorContains), error.message); + t.true( + error.errors.length === expectedError.count, + `Expected ${expectedError.count} errors, got ${error.errors.length}:\n${error.message}`, + ); + t.true(error.message.includes(expectedError.contains), error.message); } }); } @@ -67,75 +81,100 @@ function testOverride( testAccepts('Child_Of_NoInitializer_Ok', 'transparent'); testAccepts('Child_Of_InitializerModifier_Ok', 'transparent'); -testRejects( - 'Child_Of_InitializerModifier_Bad', - 'transparent', - 'Contract is missing initializer calls for one or more parent contracts: `Parent_InitializerModifier`', -); +testRejects('Child_Of_InitializerModifier_Bad', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `Parent_InitializerModifier`', + count: 1, +}); testAccepts('Child_Of_InitializerModifier_UsesSuper_Ok', 'transparent'); testAccepts('Child_Of_ReinitializerModifier_Ok', 'transparent'); -testRejects( - 'Child_Of_ReinitializerModifier_Bad', - 'transparent', - 'Contract is missing initializer calls for one or more parent contracts: `Parent_ReinitializerModifier`', -); +testRejects('Child_Of_ReinitializerModifier_Bad', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `Parent_ReinitializerModifier`', + count: 1, +}); testAccepts('Child_Of_OnlyInitializingModifier_Ok', 'transparent'); -testRejects( - 'Child_Of_OnlyInitializingModifier_Bad', - 'transparent', - 'Contract is missing initializer calls for one or more parent contracts: `Parent__OnlyInitializingModifier`', -); +testRejects('Child_Of_OnlyInitializingModifier_Bad', 'transparent', { + contains: + 'Contract is missing initializer calls for one or more parent contracts: `Parent__OnlyInitializingModifier`', + count: 1, +}); + +testRejects('MissingInitializer_Bad', 'transparent', { + contains: 'Contract is missing an initializer', -testRejects('MissingInitializer_Bad', 'transparent', 'Contract is missing an initializer'); + count: 1, +}); testAccepts('MissingInitializer_UnsafeAllow_Contract', 'transparent'); testOverride('MissingInitializer_Bad', 'transparent', { unsafeAllow: ['missing-initializer'] }); testAccepts('InitializationOrder_Ok', 'transparent'); testAccepts('InitializationOrder_Ok_2', 'transparent'); -testRejects( - 'InitializationOrder_WrongOrder_Bad', - 'transparent', - 'Contract has an incorrect order of parent initializer calls. Expected initializers to be called for parent contracts in the following order: A, B, C', -); +testRejects('InitializationOrder_WrongOrder_Bad', 'transparent', { + contains: 'Expected initializers to be called for parent contracts in the following order: A, B, C', + count: 1, +}); testAccepts('InitializationOrder_WrongOrder_UnsafeAllow_Contract', 'transparent'); testAccepts('InitializationOrder_WrongOrder_UnsafeAllow_Function', 'transparent'); testOverride('InitializationOrder_WrongOrder_Bad', 'transparent', { unsafeAllow: ['incorrect-initializer-order'] }); -testRejects( - 'InitializationOrder_MissingCall_Bad', - 'transparent', - 'Contract is missing initializer calls for one or more parent contracts: `C`', -); +testRejects('InitializationOrder_MissingCall_Bad', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `C`', + count: 1, +}); testAccepts('InitializationOrder_MissingCall_UnsafeAllow_Contract', 'transparent'); testAccepts('InitializationOrder_MissingCall_UnsafeAllow_Function', 'transparent'); testOverride('InitializationOrder_MissingCall_Bad', 'transparent', { unsafeAllow: ['missing-initializer-call'] }); -testRejects( - 'InitializationOrder_Duplicate_Bad', - 'transparent', - 'Contract has duplicate calls to parent initializer `__B_init` for contract `B`', -); +testRejects('InitializationOrder_Duplicate_Bad', 'transparent', { + contains: 'Contract has duplicate calls to parent initializer `__B_init` for contract `B`', + count: 1, +}); testAccepts('InitializationOrder_Duplicate_UnsafeAllow_Contract', 'transparent'); testAccepts('InitializationOrder_Duplicate_UnsafeAllow_Function', 'transparent'); testAccepts('InitializationOrder_Duplicate_UnsafeAllow_Call', 'transparent'); testOverride('InitializationOrder_Duplicate_Bad', 'transparent', { unsafeAllow: ['duplicate-initializer-call'] }); -testRejects( - 'InitializationOrder_UnsafeAllowDuplicate_But_WrongOrder', - 'transparent', - 'Contract has an incorrect order of parent initializer calls. Expected initializers to be called for parent contracts in the following order: A, B, C', -); +testRejects('InitializationOrder_UnsafeAllowDuplicate_But_WrongOrder', 'transparent', { + contains: 'Expected initializers to be called for parent contracts in the following order: A, B, C', + count: 1, +}); testAccepts('Child_Of_ParentPrivateInitializer_Ok', 'transparent'); testAccepts('Child_Of_ParentPublicInitializer_Ok', 'transparent'); -testRejects('Child_Has_PrivateInitializer_Bad', 'transparent', 'Contract is missing an initializer'); +testRejects('Child_Has_PrivateInitializer_Bad', 'transparent', { + contains: 'Contract is missing an initializer', + count: 1, +}); + +testAccepts('TransitiveParent_Ok', 'transparent'); +testRejects('TransitiveParent_Bad', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `TransitiveGrandparent2`', + count: 1, +}); +testRejects('TransitiveChild_Bad_Parent', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `TransitiveGrandparent2`', + count: 3, // should be 2 if we ignore wrong order. the errors are: missing for child, missing for parent +}); +testRejects('TransitiveChild_Bad_Order', 'transparent', { + contains: + 'Expected initializers to be called for parent contracts in the following order: TransitiveGrandparent2, TransitiveParent_Bad', + count: 2, +}); // should have 2 errors: 'Expected initializers to be called for parent contracts in the following order: TransitiveGrandparent2, TransitiveParent_Bad', 'Contract is missing initializer calls for one or more parent contracts: `TransitiveGrandparent2`' +// but 1 if we ignore wrong order +testRejects('TransitiveChild_Bad_Order2', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `TransitiveGrandparent2`', + count: 1, +}); +testRejects('TransitiveDuplicate_Bad', 'transparent', { + contains: 'Contract has duplicate calls to parent', + count: 1, +}); +// should allow this if we ignore duplicate calls transitively testAccepts('Ownable_Ok', 'transparent'); testAccepts('Ownable2Step_Ok', 'transparent'); -testRejects( - 'Ownable2Step_Bad', - 'transparent', - 'Contract is missing initializer calls for one or more parent contracts: `OwnableUpgradeable`', -); +testRejects('Ownable2Step_Bad', 'transparent', { + contains: 'Contract is missing initializer calls for one or more parent contracts: `OwnableUpgradeable`', + count: 1, +}); diff --git a/packages/core/src/validate/report.ts b/packages/core/src/validate/report.ts index ed598c1a7..ab518b214 100644 --- a/packages/core/src/validate/report.ts +++ b/packages/core/src/validate/report.ts @@ -91,7 +91,9 @@ const errorInfo: ErrorDescriptions = { }, 'incorrect-initializer-order': { msg: e => - `Contract has an incorrect order of parent initializer calls. Expected initializers to be called for parent contracts in the following order: ${e.expectedLinearization.join(', ')}`, + `Contract has an incorrect order of parent initializer calls. +- Expected initializers to be called for parent contracts in the following order: ${e.expectedLinearization.join(', ')} +- Found order: ${e.foundOrder.join(', ')}`, hint: () => `Call parent initializers in the order they are inherited`, link: 'https://zpl.in/upgrades/error-001', }, diff --git a/packages/core/src/validate/run.ts b/packages/core/src/validate/run.ts index 247d1a5fa..1caf5f884 100644 --- a/packages/core/src/validate/run.ts +++ b/packages/core/src/validate/run.ts @@ -97,6 +97,7 @@ interface ValidationErrorDuplicateInitializerCall extends ValidationErrorBase { interface ValidationErrorIncorrectInitializerOrder extends ValidationErrorBase { kind: 'incorrect-initializer-order'; expectedLinearization: string[]; + foundOrder: string[]; } type ValidationErrorInitializer = @@ -662,7 +663,7 @@ function* getInternalFunctionStorageErrors( } /** - * Reports an error if a parent contract has an initializer and any of the following are true: + * Reports an error this contract is non-abstract, a linearized parent contract has an initializer, and any of the following are true: * - 1. Missing initializer: This contract does not appear to have an initializer. * - 2. Missing initializer call: This contract's initializer is missing a call to a parent initializer. * - 3. Duplicate initializer call: This contract has duplicate calls to the same parent initializer function. @@ -673,21 +674,97 @@ function* getInitializerErrors( deref: ASTDereferencer, decodeSrc: SrcDecoder, ): Generator { - if (contractDef.baseContracts.length > 0) { - const baseContractDefs = contractDef.baseContracts.map(base => - deref('ContractDefinition', base.baseName.referencedDeclaration), + if (contractDef.abstract) { + return; + } + if (contractDef.linearizedBaseContracts.length > 0) { + console.log('- Checking initializers for contract [' + contractDef.name + ']'); + + const linearizedBaseContractDefs = contractDef.linearizedBaseContracts.map(base => + deref('ContractDefinition', base), ); + + // Remove the least derived contract from the list + linearizedBaseContractDefs.shift(); + // Reverse the order to start from the most derived contract + linearizedBaseContractDefs.reverse(); + const baseContractsInitializersMap = new Map( - baseContractDefs.map(base => [base.name, getPossibleInitializers(base, true)]), + linearizedBaseContractDefs.map(base => [base.name, getPossibleInitializers(base, true)]), ); - const baseContractsWithInitializers = baseContractDefs + + console.log(' -> Before removing: ' + linearizedBaseContractDefs.map(base => base.name).join(', ')); + + // For each base contract, if its initializer calls any of the earlier base contracts' intializers, it can be removed from the list. + // Ignore whether the base contracts are calling their initializers in the correct order, because we only check the order of THIS contract's calls. + for (const base of linearizedBaseContractDefs) { + const baseInitializers = baseContractsInitializersMap.get(base.name)!; + for (const initializer of baseInitializers) { + const expressionStatements = + initializer.body?.statements?.filter(stmt => stmt.nodeType === 'ExpressionStatement') ?? []; + for (const stmt of expressionStatements) { + const fnCall = stmt.expression; + if ( + fnCall.nodeType === 'FunctionCall' && + (fnCall.expression.nodeType === 'Identifier' || fnCall.expression.nodeType === 'MemberAccess') + ) { + const referencedFn = fnCall.expression.referencedDeclaration; + if (referencedFn) { + const earlierBaseContractDefs = linearizedBaseContractDefs.slice( + 0, + linearizedBaseContractDefs.indexOf(base), + ); + + const foundParentInitializer = earlierBaseContractDefs.find(base => + baseContractsInitializersMap.get(base.name)!.some(init => init.id === referencedFn), + ); + if (foundParentInitializer) { + const index = earlierBaseContractDefs.indexOf(foundParentInitializer); + if (index !== -1) { + console.log( + ' - Removing ' + + foundParentInitializer.name + + ' from linearizedBaseContractDefs because it is called by ' + + base.name, + ); + linearizedBaseContractDefs.splice(linearizedBaseContractDefs.indexOf(foundParentInitializer), 1); + } + } + } + } + } + } + } + + console.log(' -> After removing: ' + linearizedBaseContractDefs.map(base => base.name).join(', ')); + + const baseContractsToInitialize = linearizedBaseContractDefs .filter(base => baseContractsInitializersMap.get(base.name)?.length) .map(base => base.name); - if (baseContractsWithInitializers.length > 0) { + console.log(' - baseContractsToInitialize: ' + baseContractsToInitialize); + + if (baseContractsToInitialize.length > 0) { // Check for missing initializers const contractInitializers = getPossibleInitializers(contractDef, false); - if (contractInitializers.length === 0 && !skipCheck('missing-initializer', contractDef)) { + + // If there are multiple parents with initializers, the contract must have its own initializer to call them. + // If there is one parent with possible initializers: + // - they are all internal, the contract must have its own initializer + // - otherwise the contract does not need its own initializer, since one of the parent's initializers can be called during deployment + const requiresParentInitializerCall = + baseContractsToInitialize.length > 1 || + (baseContractsToInitialize.length === 1 && + baseContractsInitializersMap.get(baseContractsToInitialize[0])!.length > 0 && + baseContractsInitializersMap + .get(baseContractsToInitialize[0])! + .every(contractDef => contractDef.visibility === 'internal')); + + if ( + requiresParentInitializerCall && + contractInitializers.length === 0 && + !skipCheck('missing-initializer', contractDef) + ) { yield { kind: 'missing-initializer', src: decodeSrc(contractDef), @@ -695,7 +772,8 @@ function* getInitializerErrors( } for (const contractInitializer of contractInitializers) { - const uninitializedBaseContracts = [...baseContractsWithInitializers]; + const remaining: string[] = [...baseContractsToInitialize]; + const foundOrder: string[] = []; const calledInitializerIds: number[] = []; const expressionStatements = @@ -711,8 +789,9 @@ function* getInitializerErrors( // If this is a call to a parent initializer, then: // - Check if it was already called (duplicate call) // - Otherwise, check if the parent initializer is called in the correct order - for (const [baseName, initializers] of baseContractsInitializersMap) { - const foundParentInitializer = initializers.find(init => init.id === referencedFn); + for (const baseContractToInitialize of baseContractsToInitialize) { + const baseInitializers = baseContractsInitializersMap.get(baseContractToInitialize)!; + const foundParentInitializer = baseInitializers.find(init => init.id === referencedFn); if (referencedFn && foundParentInitializer) { const duplicate = calledInitializerIds.includes(referencedFn); if ( @@ -725,12 +804,14 @@ function* getInitializerErrors( kind: 'duplicate-initializer-call', src: decodeSrc(fnCall), parentInitializer: foundParentInitializer.name, - parentContract: baseName, + parentContract: baseContractToInitialize, }; } calledInitializerIds.push(referencedFn); - const index = uninitializedBaseContracts.indexOf(baseName); + foundOrder.push(baseContractToInitialize); + // TODO handle linearized contracts + const index = remaining.indexOf(baseContractToInitialize); if ( !duplicate && // Omit duplicate calls to avoid treating them as out of order. Duplicates are either reported above or they were skipped. index !== 0 && @@ -740,11 +821,12 @@ function* getInitializerErrors( yield { kind: 'incorrect-initializer-order', src: decodeSrc(fnCall), - expectedLinearization: baseContractsWithInitializers, + expectedLinearization: baseContractsToInitialize, + foundOrder, }; } if (index !== -1) { - uninitializedBaseContracts.splice(index, 1); + remaining.splice(index, 1); } } } @@ -753,14 +835,14 @@ function* getInitializerErrors( // If there are any base contracts that were not initialized, report an error if ( - uninitializedBaseContracts.length > 0 && + remaining.length > 0 && !skipCheck('missing-initializer-call', contractDef) && !skipCheck('missing-initializer-call', contractInitializer) ) { yield { kind: 'missing-initializer-call', src: decodeSrc(contractInitializer), - parentContracts: uninitializedBaseContracts, + parentContracts: remaining, }; } } @@ -768,6 +850,9 @@ function* getInitializerErrors( } } +/** + * Get all functions that could be initializers. Does not include private functions. + */ function getPossibleInitializers(contractDef: ContractDefinition, isParentContract: boolean) { const fns = [...findAll('FunctionDefinition', contractDef)]; return fns.filter( @@ -778,11 +863,10 @@ function getPossibleInitializers(contractDef: ContractDefinition, isParentContra ['initialize', 'initializer', 'reinitialize', 'reinitializer'].includes(fnDef.name)) && // Skip virtual functions without a body, since that indicates an abstract function and is not itself an initializer !(fnDef.virtual && !fnDef.body) && - // For parent contracts, only treat internal functions which contain statements as initializers (since they MUST be called by the child) - // For child contracts, treat all non-private functions as initializers (since they can be called by another contract or externally) - (isParentContract - ? fnDef.visibility === 'internal' && fnDef.body?.statements?.length - : fnDef.visibility !== 'private'), + // Ignore private functions, since they cannot be called outside the contract + fnDef.visibility !== 'private' && + // For parent contracts, only functions which contain statements need to be called + (isParentContract ? fnDef.body?.statements?.length : true), ); }