diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 2eb0113487..a2f4b5a0d1 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -32,6 +32,7 @@ jobs: # docs: documentation # test: test suite # lang: language changes + # stdlib: changes to the stdlib # ux: language changes (UX) # tool: integration # ir: (old) IR/codegen changes @@ -43,6 +44,7 @@ jobs: docs test lang + stdlib ux tool ir diff --git a/docs/control-structures.rst b/docs/control-structures.rst index d46e7a4a28..6304637728 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -48,7 +48,16 @@ External functions (marked with the ``@external`` decorator) are a part of the c A Vyper contract cannot call directly between two external functions. If you must do this, you can use an :ref:`interface `. .. note:: - For external functions with default arguments like ``def my_function(x: uint256, b: uint256 = 1)`` the Vyper compiler will generate ``N+1`` overloaded function selectors based on ``N`` default arguments. + For external functions with default arguments like ``def my_function(x: uint256, b: uint256 = 1)`` the Vyper compiler will generate ``N+1`` overloaded function selectors based on ``N`` default arguments. Consequently, the ABI signature for a function (this includes interface functions) excludes optional arguments when their default values are used in the function call. + + .. code-block:: vyper + + from ethereum.ercs import IERC4626 + + @external + def foo(x: IERC4626): + extcall x.withdraw(0, self, self) # keccak256("withdraw(uint256,address,address)")[:4] = 0xb460af94 + extcall x.withdraw(0) # keccak256("withdraw(uint256)")[:4] = 0x2e1a7d4d .. _structure-functions-internal: @@ -75,6 +84,14 @@ Or for internal functions which are defined in :ref:`imported modules ` def calculate(amount: uint256) -> uint256: return calculator_library._times_two(amount) +Marking an internal function as ``payable`` specifies that the function can interact with ``msg.value``. A ``nonpayable`` internal function can be called from an external ``payable`` function, but it cannot access ``msg.value``. + +.. code-block:: vyper + + @payable + def _foo() -> uint256: + return msg.value % 2 + .. note:: As of v0.4.0, the ``@internal`` decorator is optional. That is, functions with no visibility decorator default to being ``internal``. @@ -110,7 +127,7 @@ You can optionally declare a function's mutability by using a :ref:`decorator uint256: input_bundle = make_input_bundle({"lib1.vy": lib1}) c = get_contract(main, input_bundle=input_bundle) assert c.bar() == 1 + + +def test_interface_with_flags(): + code = """ +struct MyStruct: + a: address + +flag Foo: + BOO + MOO + POO + +event Transfer: + sender: indexed(address) + +@external +def bar(): + pass +flag BAR: + BIZ + BAZ + BOO + +@external +@view +def foo(s: MyStruct) -> MyStruct: + return s + """ + + out = compile_code(code, contract_path="code.vy", output_formats=["interface"])["interface"] + + assert "# Flags" in out + assert "flag Foo:" in out + assert "flag BAR" in out + assert "BOO" in out + assert "MOO" in out + + compile_code(out, contract_path="code.vyi", output_formats=["interface"]) + + +vyi_filenames = [ + "test__test.vyi", + "test__t.vyi", + "t__test.vyi", + "t__t.vyi", + "t_t.vyi", + "test_test.vyi", + "t_test.vyi", + "test_t.vyi", + "_test_t__t_tt_.vyi", + "foo_bar_baz.vyi", +] + + +@pytest.mark.parametrize("vyi_filename", vyi_filenames) +def test_external_interface_names(vyi_filename): + code = """ +@external +def foo(): + ... + """ + + compile_code(code, contract_path=vyi_filename, output_formats=["external_interface"]) + + +def test_external_interface_with_flag(): + code = """ +flag Foo: + Blah + +@external +def foo() -> Foo: + ... + """ + + out = compile_code(code, contract_path="test__test.vyi", output_formats=["external_interface"])[ + "external_interface" + ] + assert "-> Foo:" in out + + +def test_external_interface_compiles_again(): + code = """ +@external +def foo() -> uint256: + ... +@external +def bar(a:int32) -> uint256: + ... + """ + + out = compile_code(code, contract_path="test.vyi", output_formats=["external_interface"])[ + "external_interface" + ] + compile_code(out, contract_path="test.vyi", output_formats=["external_interface"]) + + +@pytest.mark.xfail +def test_weird_interface_name(): + # based on comment https://github.com/vyperlang/vyper/pull/4290#discussion_r1884137428 + # we replace "_" for "" which results in an interface without name + out = compile_code("", contract_path="_.vyi", output_formats=["external_interface"])[ + "external_interface" + ] + assert "interface _:" in out diff --git a/tests/functional/venom/parser/test_parsing.py b/tests/functional/venom/parser/test_parsing.py index f1fc59cf40..bd536a8cfa 100644 --- a/tests/functional/venom/parser/test_parsing.py +++ b/tests/functional/venom/parser/test_parsing.py @@ -1,6 +1,6 @@ -from tests.venom_utils import assert_ctx_eq -from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable -from vyper.venom.context import IRContext +from tests.venom_utils import assert_bb_eq, assert_ctx_eq +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral, IRVariable +from vyper.venom.context import DataItem, DataSection, IRContext from vyper.venom.function import IRFunction from vyper.venom.parser import parse_venom @@ -11,8 +11,6 @@ def test_single_bb(): main: stop } - - [data] """ parsed_ctx = parse_venom(source) @@ -38,8 +36,6 @@ def test_multi_bb_single_fn(): has_callvalue: revert 0, 0 } - - [data] """ parsed_ctx = parse_venom(source) @@ -61,8 +57,6 @@ def test_multi_bb_single_fn(): has_callvalue_bb.append_instruction("revert", IRLiteral(0), IRLiteral(0)) has_callvalue_bb.append_instruction("stop") - start_fn.last_variable = 4 - assert_ctx_eq(parsed_ctx, expected_ctx) @@ -74,15 +68,16 @@ def test_data_section(): stop } - [data] - dbname @selector_buckets - db @selector_bucket_0 - db @fallback - db @selector_bucket_2 - db @selector_bucket_3 - db @fallback - db @selector_bucket_5 - db @selector_bucket_6 + data readonly { + dbsection selector_buckets: + db @selector_bucket_0 + db @fallback + db @selector_bucket_2 + db @selector_bucket_3 + db @fallback + db @selector_bucket_5 + db @selector_bucket_6 + } """ ) @@ -91,14 +86,18 @@ def test_data_section(): entry_fn.get_basic_block("entry").append_instruction("stop") expected_ctx.data_segment = [ - IRInstruction("dbname", [IRLabel("selector_buckets")]), - IRInstruction("db", [IRLabel("selector_bucket_0")]), - IRInstruction("db", [IRLabel("fallback")]), - IRInstruction("db", [IRLabel("selector_bucket_2")]), - IRInstruction("db", [IRLabel("selector_bucket_3")]), - IRInstruction("db", [IRLabel("fallback")]), - IRInstruction("db", [IRLabel("selector_bucket_5")]), - IRInstruction("db", [IRLabel("selector_bucket_6")]), + DataSection( + IRLabel("selector_buckets"), + [ + DataItem(IRLabel("selector_bucket_0")), + DataItem(IRLabel("fallback")), + DataItem(IRLabel("selector_bucket_2")), + DataItem(IRLabel("selector_bucket_3")), + DataItem(IRLabel("fallback")), + DataItem(IRLabel("selector_bucket_5")), + DataItem(IRLabel("selector_bucket_6")), + ], + ) ] assert_ctx_eq(parsed_ctx, expected_ctx) @@ -126,8 +125,6 @@ def test_multi_function(): has_value: revert 0, 0 } - - [data] """ ) @@ -157,8 +154,6 @@ def test_multi_function(): value_bb.append_instruction("revert", IRLiteral(0), IRLiteral(0)) value_bb.append_instruction("stop") - check_fn.last_variable = 2 - assert_ctx_eq(parsed_ctx, expected_ctx) @@ -185,13 +180,14 @@ def test_multi_function_and_data(): revert 0, 0 } - [data] - dbname @selector_buckets - db @selector_bucket_0 - db @fallback - db @selector_bucket_2 - db @selector_bucket_3 - db @selector_bucket_6 + data readonly { + dbsection selector_buckets: + db @selector_bucket_0 + db @fallback + db @selector_bucket_2 + db @selector_bucket_3 + db @selector_bucket_6 + } """ ) @@ -221,15 +217,136 @@ def test_multi_function_and_data(): value_bb.append_instruction("revert", IRLiteral(0), IRLiteral(0)) value_bb.append_instruction("stop") - check_fn.last_variable = 2 - expected_ctx.data_segment = [ - IRInstruction("dbname", [IRLabel("selector_buckets")]), - IRInstruction("db", [IRLabel("selector_bucket_0")]), - IRInstruction("db", [IRLabel("fallback")]), - IRInstruction("db", [IRLabel("selector_bucket_2")]), - IRInstruction("db", [IRLabel("selector_bucket_3")]), - IRInstruction("db", [IRLabel("selector_bucket_6")]), + DataSection( + IRLabel("selector_buckets"), + [ + DataItem(IRLabel("selector_bucket_0")), + DataItem(IRLabel("fallback")), + DataItem(IRLabel("selector_bucket_2")), + DataItem(IRLabel("selector_bucket_3")), + DataItem(IRLabel("selector_bucket_6")), + ], + ) ] assert_ctx_eq(parsed_ctx, expected_ctx) + + +def test_phis(): + # @external + # def _loop() -> uint256: + # res: uint256 = 9 + # for i: uint256 in range(res, bound=10): + # res = res + i + # return res + source = """ + function __main_entry { + __main_entry: ; IN=[] OUT=[fallback, 1_then] => {} + %27 = 0 + %1 = calldataload %27 + %28 = %1 + %29 = 224 + %2 = shr %29, %28 + %31 = %2 + %30 = 1729138561 + %4 = xor %30, %31 + %32 = %4 + jnz %32, @fallback, @1_then + ; (__main_entry) + + + 1_then: ; IN=[__main_entry] OUT=[4_condition] => {%11, %var8_0} + %6 = callvalue + %33 = %6 + %7 = iszero %33 + %34 = %7 + assert %34 + %var8_0 = 9 + %11 = 0 + nop + jmp @4_condition + ; (__main_entry) + + + 4_condition: ; IN=[1_then, 5_body] OUT=[5_body, 7_exit] => {%11:3, %var8_0:2} + %var8_0:2 = phi @1_then, %var8_0, @5_body, %var8_0:3 + %11:3 = phi @1_then, %11, @5_body, %11:4 + %35 = %11:3 + %36 = 9 + %15 = xor %36, %35 + %37 = %15 + jnz %37, @5_body, @7_exit + ; (__main_entry) + + + 5_body: ; IN=[4_condition] OUT=[4_condition] => {%11:4, %var8_0:3} + %38 = %11:3 + %39 = %var8_0:2 + %22 = add %39, %38 + %41 = %22 + %40 = %var8_0:2 + %24 = gt %40, %41 + %42 = %24 + %25 = iszero %42 + %43 = %25 + assert %43 + %var8_0:3 = %22 + %44 = %11:3 + %45 = 1 + %11:4 = add %45, %44 + jmp @4_condition + ; (__main_entry) + + + 7_exit: ; IN=[4_condition] OUT=[] => {} + %46 = %var8_0:2 + %47 = 64 + mstore %47, %46 + %48 = 32 + %49 = 64 + return %49, %48 + ; (__main_entry) + + + fallback: ; IN=[__main_entry] OUT=[] => {} + %50 = 0 + %51 = 0 + revert %51, %50 + stop + ; (__main_entry) + } ; close function __main_entry + """ + ctx = parse_venom(source) + + expected_ctx = IRContext() + expected_ctx.add_function(entry_fn := IRFunction(IRLabel("__main_entry"))) + + expect_bb = IRBasicBlock(IRLabel("4_condition"), entry_fn) + entry_fn.append_basic_block(expect_bb) + + expect_bb.append_instruction( + "phi", + IRLabel("1_then"), + IRVariable("%var8_0"), + IRLabel("5_body"), + IRVariable("%var8_0:3"), + ret=IRVariable("var8_0:2"), + ) + expect_bb.append_instruction( + "phi", + IRLabel("1_then"), + IRVariable("%11"), + IRLabel("5_body"), + IRVariable("%11:4"), + ret=IRVariable("11:3"), + ) + expect_bb.append_instruction("store", IRVariable("11:3"), ret=IRVariable("%35")) + expect_bb.append_instruction("store", IRLiteral(9), ret=IRVariable("%36")) + expect_bb.append_instruction("xor", IRVariable("%35"), IRVariable("%36"), ret=IRVariable("%15")) + expect_bb.append_instruction("store", IRVariable("%15"), ret=IRVariable("%37")) + expect_bb.append_instruction("jnz", IRVariable("%37"), IRLabel("5_body"), IRLabel("7_exit")) + # other basic blocks omitted for brevity + + parsed_fn = next(iter(ctx.functions.values())) + assert_bb_eq(parsed_fn.get_basic_block(expect_bb.label.name), expect_bb) diff --git a/tests/functional/venom/test_venom_repr.py b/tests/functional/venom/test_venom_repr.py new file mode 100644 index 0000000000..1fb5d0486a --- /dev/null +++ b/tests/functional/venom/test_venom_repr.py @@ -0,0 +1,126 @@ +import copy +import glob +import textwrap + +import pytest + +from tests.venom_utils import assert_ctx_eq, parse_venom +from vyper.compiler import compile_code +from vyper.compiler.phases import generate_bytecode +from vyper.compiler.settings import OptimizationLevel +from vyper.venom import generate_assembly_experimental, run_passes_on +from vyper.venom.context import IRContext + +""" +Check that venom text format round-trips through parser +""" + + +def get_example_vy_filenames(): + return glob.glob("**/*.vy", root_dir="examples/", recursive=True) + + +@pytest.mark.parametrize("vy_filename", get_example_vy_filenames()) +def test_round_trip_examples(vy_filename, debug, optimize, compiler_settings, request): + """ + Check all examples round trip + """ + path = f"examples/{vy_filename}" + with open(path) as f: + vyper_source = f.read() + + if debug and optimize == OptimizationLevel.CODESIZE: + # FIXME: some round-trips fail when debug is enabled due to labels + # not getting pinned + request.node.add_marker(pytest.mark.xfail(strict=False)) + + _round_trip_helper(vyper_source, optimize, compiler_settings) + + +# pure vyper sources +vyper_sources = [ + """ + @external + def _loop() -> uint256: + res: uint256 = 9 + for i: uint256 in range(res, bound=10): + res = res + i + return res + """ +] + + +@pytest.mark.parametrize("vyper_source", vyper_sources) +def test_round_trip_sources(vyper_source, debug, optimize, compiler_settings, request): + """ + Test vyper_sources round trip + """ + vyper_source = textwrap.dedent(vyper_source) + + if debug and optimize == OptimizationLevel.CODESIZE: + # FIXME: some round-trips fail when debug is enabled due to labels + # not getting pinned + request.node.add_marker(pytest.mark.xfail(strict=False)) + + _round_trip_helper(vyper_source, optimize, compiler_settings) + + +def _round_trip_helper(vyper_source, optimize, compiler_settings): + # helper function to test venom round-tripping thru the parser + # use two helpers because run_passes_on and + # generate_assembly_experimental are both destructive (mutating) on + # the IRContext + _helper1(vyper_source, optimize) + _helper2(vyper_source, optimize, compiler_settings) + + +def _helper1(vyper_source, optimize): + """ + Check that we are able to run passes on the round-tripped venom code + and that it is valid (generates bytecode) + """ + # note: compiling any later stage than bb_runtime like `asm` or + # `bytecode` modifies the bb_runtime data structure in place and results + # in normalization of the venom cfg (which breaks again make_ssa) + out = compile_code(vyper_source, output_formats=["bb_runtime"]) + + bb_runtime = out["bb_runtime"] + venom_code = IRContext.__repr__(bb_runtime) + + ctx = parse_venom(venom_code) + + assert_ctx_eq(bb_runtime, ctx) + + # check it's valid to run venom passes+analyses + # (note this breaks bytecode equality, in the future we should + # test that separately) + run_passes_on(ctx, optimize) + + # test we can generate assembly+bytecode + asm = generate_assembly_experimental(ctx) + generate_bytecode(asm, compiler_metadata=None) + + +def _helper2(vyper_source, optimize, compiler_settings): + """ + Check that we can compile to bytecode, and without running venom passes, + that the output bytecode is equal to going through the normal vyper pipeline + """ + settings = copy.copy(compiler_settings) + # bytecode equivalence only makes sense if we use venom pipeline + settings.experimental_codegen = True + + out = compile_code(vyper_source, settings=settings, output_formats=["bb_runtime"]) + bb_runtime = out["bb_runtime"] + venom_code = IRContext.__repr__(bb_runtime) + + ctx = parse_venom(venom_code) + + assert_ctx_eq(bb_runtime, ctx) + + # test we can generate assembly+bytecode + asm = generate_assembly_experimental(ctx, optimize=optimize) + bytecode = generate_bytecode(asm, compiler_metadata=None) + + out = compile_code(vyper_source, settings=settings, output_formats=["bytecode_runtime"]) + assert "0x" + bytecode.hex() == out["bytecode_runtime"] diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index c9d7248809..196b1e24e6 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -399,6 +399,7 @@ def foo(): "node_id": 0, "path": "main.vy", "source_id": 1, + "is_interface": False, "type": { "name": "main.vy", "type_decl_node": {"node_id": 0, "source_id": 1}, @@ -1175,6 +1176,7 @@ def foo(): "node_id": 0, "path": "lib1.vy", "source_id": 0, + "is_interface": False, "type": { "name": "lib1.vy", "type_decl_node": {"node_id": 0, "source_id": 0}, diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 7660930c26..007abb3512 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -10,6 +10,7 @@ from vyper.cli.compile_archive import compiler_data_from_zip from vyper.cli.vyper_compile import compile_files from vyper.cli.vyper_json import compile_from_input_dict, compile_json +from vyper.compiler import INTERFACE_OUTPUT_FORMATS, OUTPUT_FORMATS from vyper.compiler.input_bundle import FilesystemInputBundle from vyper.compiler.output_bundle import OutputBundle from vyper.compiler.phases import CompilerData @@ -411,6 +412,105 @@ def test_archive_b64_output(input_files): assert out[contract_file] == out2[archive_path] +def test_archive_compile_options(input_files): + tmpdir, _, _, contract_file = input_files + search_paths = [".", tmpdir] + + options = ["abi_python", "json", "ast", "annotated_ast", "ir_json"] + + for option in options: + out = compile_files([contract_file], ["archive_b64", option], paths=search_paths) + + archive_b64 = out[contract_file].pop("archive_b64") + + archive_path = Path("foo.zip.b64") + with archive_path.open("w") as f: + f.write(archive_b64) + + # compare compiling the two input bundles + out2 = compile_files([archive_path], [option]) + + if option in ["ast", "annotated_ast"]: + # would have to normalize paths and imports, so just verify it compiles + continue + + assert out[contract_file] == out2[archive_path] + + +format_options = [ + "bytecode", + "bytecode_runtime", + "blueprint_bytecode", + "abi", + "abi_python", + "source_map", + "source_map_runtime", + "method_identifiers", + "userdoc", + "devdoc", + "metadata", + "combined_json", + "layout", + "ast", + "annotated_ast", + "interface", + "external_interface", + "opcodes", + "opcodes_runtime", + "ir", + "ir_json", + "ir_runtime", + "asm", + "integrity", + "archive", + "solc_json", +] + + +def test_compile_vyz_with_options(input_files): + tmpdir, _, _, contract_file = input_files + search_paths = [".", tmpdir] + + for option in format_options: + out_archive = compile_files([contract_file], ["archive"], paths=search_paths) + + archive = out_archive[contract_file].pop("archive") + + archive_path = Path("foo.zip.out.vyz") + with archive_path.open("wb") as f: + f.write(archive) + + # compare compiling the two input bundles + out = compile_files([contract_file], [option], paths=search_paths) + out2 = compile_files([archive_path], [option]) + + if option in ["ast", "annotated_ast", "metadata"]: + # would have to normalize paths and imports, so just verify it compiles + continue + + if option in ["ir_runtime", "ir", "archive"]: + # ir+ir_runtime is different due to being different compiler runs + # archive is different due to different metadata (timestamps) + continue + + assert out[contract_file] == out2[archive_path] + + +def test_archive_compile_simultaneous_options(input_files): + tmpdir, _, _, contract_file = input_files + search_paths = [".", tmpdir] + + for option in format_options: + with pytest.raises(ValueError) as e: + _ = compile_files([contract_file], ["archive", option], paths=search_paths) + + err_opt = "archive" + if option in ("combined_json", "solc_json"): + err_opt = option + + assert f"If using {err_opt} it must be the only output format requested" in str(e.value) + + def test_solc_json_output(input_files): tmpdir, _, _, storage_layout_path, contract_file, integrity = input_files search_paths = [".", tmpdir] @@ -491,3 +591,31 @@ def test_archive_search_path(tmp_path_factory, make_file, chdir_tmp_path): used_dir = search_paths[-1].stem # either dir1 or dir2 assert output_bundle.used_search_paths == [".", "0/" + used_dir] + + +def test_compile_interface_file(make_file): + interface = """ +@view +@external +def foo() -> String[1]: + ... + +@view +@external +def bar() -> String[1]: + ... + +@external +def baz() -> uint8: + ... + + """ + file = make_file("interface.vyi", interface) + compile_files([file], INTERFACE_OUTPUT_FORMATS) + + # check unallowed output formats + for f in OUTPUT_FORMATS: + if f in INTERFACE_OUTPUT_FORMATS: + continue + with pytest.raises(ValueError): + compile_files([file], [f]) diff --git a/tests/unit/cli/vyper_json/test_compile_json.py b/tests/unit/cli/vyper_json/test_compile_json.py index 765709b526..f921d250a4 100644 --- a/tests/unit/cli/vyper_json/test_compile_json.py +++ b/tests/unit/cli/vyper_json/test_compile_json.py @@ -89,7 +89,7 @@ def oopsie(a: uint256) -> bool: @pytest.fixture(scope="function") -def input_json(optimize, evm_version, experimental_codegen): +def input_json(optimize, evm_version, experimental_codegen, debug): return { "language": "Vyper", "sources": { @@ -103,6 +103,7 @@ def input_json(optimize, evm_version, experimental_codegen): "optimize": optimize.name.lower(), "evmVersion": evm_version, "experimentalCodegen": experimental_codegen, + "debug": debug, }, "storage_layout_overrides": { "contracts/foo.vy": FOO_STORAGE_LAYOUT_OVERRIDES, diff --git a/tests/unit/compiler/test_bytecode_runtime.py b/tests/unit/compiler/test_bytecode_runtime.py index 1d38130c49..9fdc4c493f 100644 --- a/tests/unit/compiler/test_bytecode_runtime.py +++ b/tests/unit/compiler/test_bytecode_runtime.py @@ -54,7 +54,7 @@ def test_bytecode_runtime(): assert out["bytecode_runtime"].removeprefix("0x") in out["bytecode"].removeprefix("0x") -def test_bytecode_signature(): +def test_bytecode_signature(optimize, debug): out = vyper.compile_code( simple_contract_code, output_formats=["bytecode_runtime", "bytecode", "integrity"] ) @@ -65,10 +65,16 @@ def test_bytecode_signature(): metadata = _parse_cbor_metadata(initcode) integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + if debug and optimize == OptimizationLevel.CODESIZE: + # debug forces dense jumptable no matter the size of selector table + expected_data_section_lengths = [5, 7] + else: + expected_data_section_lengths = [] + assert integrity_hash.hex() == out["integrity"] assert runtime_len == len(runtime_code) - assert data_section_lengths == [] + assert data_section_lengths == expected_data_section_lengths assert immutables_len == 0 assert compiler == {"vyper": list(vyper.version.version_tuple)} @@ -119,7 +125,7 @@ def test_bytecode_signature_sparse_jumptable(): assert compiler == {"vyper": list(vyper.version.version_tuple)} -def test_bytecode_signature_immutables(): +def test_bytecode_signature_immutables(debug, optimize): out = vyper.compile_code( has_immutables, output_formats=["bytecode_runtime", "bytecode", "integrity"] ) @@ -130,10 +136,16 @@ def test_bytecode_signature_immutables(): metadata = _parse_cbor_metadata(initcode) integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata + if debug and optimize == OptimizationLevel.CODESIZE: + # debug forces dense jumptable no matter the size of selector table + expected_data_section_lengths = [5, 7] + else: + expected_data_section_lengths = [] + assert integrity_hash.hex() == out["integrity"] assert runtime_len == len(runtime_code) - assert data_section_lengths == [] + assert data_section_lengths == expected_data_section_lengths assert immutables_len == 32 assert compiler == {"vyper": list(vyper.version.version_tuple)} diff --git a/tests/unit/compiler/venom/test_literals_codesize.py b/tests/unit/compiler/venom/test_literals_codesize.py new file mode 100644 index 0000000000..4de4d9de64 --- /dev/null +++ b/tests/unit/compiler/venom/test_literals_codesize.py @@ -0,0 +1,117 @@ +import pytest + +from vyper.utils import evm_not +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.basicblock import IRLiteral +from vyper.venom.context import IRContext +from vyper.venom.passes import ReduceLiteralsCodesize + + +def _calc_push_size(val: int): + s = hex(val).removeprefix("0x") + if len(s) % 2 != 0: # justify to multiple of 2 + s = "0" + s + return 1 + len(s) + + +should_invert = [2**256 - 1] + [((2**i) - 1) << (256 - i) for i in range(121, 256 + 1)] + + +@pytest.mark.parametrize("orig_value", should_invert) +def test_literal_codesize_ff_inversion(orig_value): + """ + Test that literals like 0xfffffffffffabcd get inverted to `not 0x5432` + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + inst0 = bb.instructions[0] + assert inst0.opcode == "not" + op0 = inst0.operands[0] + assert evm_not(op0.value) == orig_value + # check the optimization actually improved codesize, after accounting + # for the addl NOT instruction + assert _calc_push_size(op0.value) + 1 < _calc_push_size(orig_value) + + +should_not_invert = [1, 0xFE << 248 | (2**248 - 1)] + [ + ((2**255 - 1) >> i) << i for i in range(0, 3 * 8) +] + + +@pytest.mark.parametrize("orig_value", should_not_invert) +def test_literal_codesize_no_inversion(orig_value): + """ + Check funky cases where inversion would result in bytecode increase + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + assert bb.instructions[0].opcode == "store" + assert bb.instructions[0].operands[0].value == orig_value + + +should_shl = ( + [2**i for i in range(3 * 8, 255)] + + [((2**i) - 1) << (256 - i) for i in range(1, 121)] + + [((2**255 - 1) >> i) << i for i in range(3 * 8, 254)] +) + + +@pytest.mark.parametrize("orig_value", should_shl) +def test_literal_codesize_shl(orig_value): + """ + Test that literals like 0xabcd00000000 get transformed to `shl 32 0xabcd` + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + assert bb.instructions[0].opcode == "shl" + op0, op1 = bb.instructions[0].operands + assert op0.value << op1.value == orig_value + + # check the optimization actually improved codesize, after accounting + # for the addl PUSH and SHL instructions + assert _calc_push_size(op0.value) + _calc_push_size(op1.value) + 1 < _calc_push_size(orig_value) + + +should_not_shl = [1 << i for i in range(0, 3 * 8)] + [ + 0x0, + (((2 ** (256 - 2)) - 1) << (2 * 8)) ^ (2**255), +] + + +@pytest.mark.parametrize("orig_value", should_not_shl) +def test_literal_codesize_no_shl(orig_value): + """ + Check funky cases where shl transformation would result in bytecode increase + """ + ctx = IRContext() + fn = ctx.create_function("_global") + bb = fn.get_basic_block() + + bb.append_instruction("store", IRLiteral(orig_value)) + bb.append_instruction("stop") + ac = IRAnalysesCache(fn) + ReduceLiteralsCodesize(ac, fn).run_pass() + + assert bb.instructions[0].opcode == "store" + assert bb.instructions[0].operands[0].value == orig_value diff --git a/tests/unit/compiler/venom/test_load_elimination.py b/tests/unit/compiler/venom/test_load_elimination.py new file mode 100644 index 0000000000..52c7baf3c9 --- /dev/null +++ b/tests/unit/compiler/venom/test_load_elimination.py @@ -0,0 +1,129 @@ +from tests.venom_utils import assert_ctx_eq, parse_from_basic_block +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.passes.load_elimination import LoadElimination + + +def _check_pre_post(pre, post): + ctx = parse_from_basic_block(pre) + + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + LoadElimination(ac, fn).run_pass() + + assert_ctx_eq(ctx, parse_from_basic_block(post)) + + +def _check_no_change(pre): + _check_pre_post(pre, pre) + + +def test_simple_load_elimination(): + pre = """ + main: + %ptr = 11 + %1 = mload %ptr + + %2 = mload %ptr + + stop + """ + post = """ + main: + %ptr = 11 + %1 = mload %ptr + + %2 = %1 + + stop + """ + _check_pre_post(pre, post) + + +def test_equivalent_var_elimination(): + """ + Test that the lattice can "peer through" equivalent vars + """ + pre = """ + main: + %1 = 11 + %2 = %1 + %3 = mload %1 + + %4 = mload %2 + + stop + """ + post = """ + main: + %1 = 11 + %2 = %1 + %3 = mload %1 + + %4 = %3 # %2 == %1 + + stop + """ + _check_pre_post(pre, post) + + +def test_elimination_barrier(): + """ + Check for barrier between load/load + """ + pre = """ + main: + %1 = 11 + %2 = mload %1 + %3 = %100 + # fence - writes to memory + staticcall %3, %3, %3, %3 + %4 = mload %1 + """ + _check_no_change(pre) + + +def test_store_load_elimination(): + """ + Check that lattice stores the result of mstores (even through + equivalent variables) + """ + pre = """ + main: + %val = 55 + %ptr1 = 11 + %ptr2 = %ptr1 + mstore %ptr1, %val + + %3 = mload %ptr2 + + stop + """ + post = """ + main: + %val = 55 + %ptr1 = 11 + %ptr2 = %ptr1 + mstore %ptr1, %val + + %3 = %val + + stop + """ + _check_pre_post(pre, post) + + +def test_store_load_barrier(): + """ + Check for barrier between store/load + """ + pre = """ + main: + %ptr = 11 + %val = 55 + mstore %ptr, %val + %3 = %100 ; arbitrary + # fence + staticcall %3, %3, %3, %3 + %4 = mload %ptr + """ + _check_no_change(pre) diff --git a/tests/unit/compiler/venom/test_make_ssa.py b/tests/unit/compiler/venom/test_make_ssa.py index aa3fead6bf..7f6b2c0cba 100644 --- a/tests/unit/compiler/venom/test_make_ssa.py +++ b/tests/unit/compiler/venom/test_make_ssa.py @@ -1,48 +1,52 @@ +from tests.venom_utils import assert_ctx_eq, parse_venom from vyper.venom.analysis import IRAnalysesCache -from vyper.venom.basicblock import IRBasicBlock, IRLabel -from vyper.venom.context import IRContext from vyper.venom.passes import MakeSSA -def test_phi_case(): - ctx = IRContext() - fn = ctx.create_function("_global") - - bb = fn.get_basic_block() - - bb_cont = IRBasicBlock(IRLabel("condition"), fn) - bb_then = IRBasicBlock(IRLabel("then"), fn) - bb_else = IRBasicBlock(IRLabel("else"), fn) - bb_if_exit = IRBasicBlock(IRLabel("if_exit"), fn) - fn.append_basic_block(bb_cont) - fn.append_basic_block(bb_then) - fn.append_basic_block(bb_else) - fn.append_basic_block(bb_if_exit) - - v = bb.append_instruction("mload", 64) - bb_cont.append_instruction("jnz", v, bb_then.label, bb_else.label) - - bb_if_exit.append_instruction("add", v, 1, ret=v) - bb_if_exit.append_instruction("jmp", bb_cont.label) +def _check_pre_post(pre, post): + ctx = parse_venom(pre) + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + assert_ctx_eq(ctx, parse_venom(post)) - bb_then.append_instruction("assert", bb_then.append_instruction("mload", 96)) - bb_then.append_instruction("jmp", bb_if_exit.label) - bb_else.append_instruction("jmp", bb_if_exit.label) - bb.append_instruction("jmp", bb_cont.label) - - ac = IRAnalysesCache(fn) - MakeSSA(ac, fn).run_pass() - - condition_block = fn.get_basic_block("condition") - assert len(condition_block.instructions) == 2 - - phi_inst = condition_block.instructions[0] - assert phi_inst.opcode == "phi" - assert phi_inst.operands[0].name == "_global" - assert phi_inst.operands[1].name == "%1" - assert phi_inst.operands[2].name == "if_exit" - assert phi_inst.operands[3].name == "%1" - assert phi_inst.output.name == "%1" - assert phi_inst.output.value != phi_inst.operands[1].value - assert phi_inst.output.value != phi_inst.operands[3].value +def test_phi_case(): + pre = """ + function loop { + main: + %v = mload 64 + jmp @test + test: + jnz %v, @then, @else + then: + %t = mload 96 + assert %t + jmp @if_exit + else: + jmp @if_exit + if_exit: + %v = add %v, 1 + jmp @test + } + """ + post = """ + function loop { + main: + %v = mload 64 + jmp @test + test: + %v:1 = phi @main, %v, @if_exit, %v:2 + jnz %v:1, @then, @else + then: + %t = mload 96 + assert %t + jmp @if_exit + else: + jmp @if_exit + if_exit: + %v:2 = add %v:1, 1 + jmp @test + } + """ + _check_pre_post(pre, post) diff --git a/tests/unit/compiler/venom/test_memmerging.py b/tests/unit/compiler/venom/test_memmerging.py new file mode 100644 index 0000000000..d309752621 --- /dev/null +++ b/tests/unit/compiler/venom/test_memmerging.py @@ -0,0 +1,1065 @@ +import pytest + +from tests.venom_utils import assert_ctx_eq, parse_from_basic_block, parse_venom +from vyper.evm.opcodes import version_check +from vyper.venom.analysis import IRAnalysesCache +from vyper.venom.passes import SCCP, MemMergePass + + +def _check_pre_post(pre, post): + ctx = parse_from_basic_block(pre) + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + MemMergePass(ac, fn).run_pass() + assert_ctx_eq(ctx, parse_from_basic_block(post)) + + +def _check_no_change(pre): + _check_pre_post(pre, pre) + + +# for parametrizing tests +LOAD_COPY = [("dload", "dloadbytes"), ("calldataload", "calldatacopy")] + + +def test_memmerging(): + """ + Basic memory merge test + All mloads and mstores can be + transformed into mcopy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 1000, %1 + mstore 1032, %2 + mstore 1064, %3 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_out_of_order(): + """ + interleaved mloads/mstores which can be transformed into mcopy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 32 + %2 = mload 0 + mstore 132, %1 + %3 = mload 64 + mstore 164, %3 + mstore 100, %2 + stop + """ + + post = """ + _global: + mcopy 100, 0, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_imposs(): + """ + Test case of impossible merge + Impossible because of the overlap + [0 96] + [32 128] + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 32, %1 + + ; BARRIER - overlap between src and dst + ; (writes to source of potential mcopy) + mstore 64, %2 + + mstore 96, %3 + stop + """ + _check_no_change(pre) + + +def test_memmerging_imposs_mstore(): + """ + Test case of impossible merge + Impossible because of the mstore barrier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 16 + mstore 1000, %1 + %3 = mload 1000 ; BARRIER - load from dst of potential mcopy + mstore 1016, %2 + mstore 2000, %3 + stop + """ + _check_no_change(pre) + + +@pytest.mark.xfail +def test_memmerging_bypass_fence(): + """ + We should be able to optimize this to an mcopy(0, 1000, 64), but + currently do not + """ + if not version_check(begin="cancun"): + raise AssertionError() # xfail + + pre = """ + function _global { + _global: + %1 = mload 0 + %2 = mload 32 + mstore %1, 1000 + %3 = mload 1000 + mstore 1032, %2 + mstore 2000, %3 + stop + } + """ + + ctx = parse_venom(pre) + + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + MemMergePass(ac, fn).run_pass() + + fn = next(iter(ctx.functions.values())) + bb = fn.entry + assert any(inst.opcode == "mcopy" for inst in bb.instructions) + + +def test_memmerging_imposs_unkown_place(): + """ + Test case of impossible merge + Impossible because of the + non constant address mload and mstore barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = param + %2 = mload 0 + %3 = mload %1 ; BARRIER + %4 = mload 32 + %5 = mload 64 + mstore 1000, %2 + mstore 1032, %4 + mstore 10, %1 ; BARRIER + mstore 1064, %5 + stop + """ + _check_no_change(pre) + + +def test_memmerging_imposs_msize(): + """ + Test case of impossible merge + Impossible because of the msize barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = msize ; BARRIER + %3 = mload 32 + %4 = mload 64 + mstore 1000, %1 + mstore 1032, %3 + %5 = msize ; BARRIER + mstore 1064, %4 + return %2, %5 + """ + _check_no_change(pre) + + +def test_memmerging_partial_msize(): + """ + Only partial merge possible + because of the msize barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 1000, %1 + mstore 1032, %2 + %4 = msize ; BARRIER + mstore 1064, %3 + return %4 + """ + + post = """ + _global: + %3 = mload 64 + mcopy 1000, 0, 64 + %4 = msize + mstore 1064, %3 + return %4 + """ + _check_pre_post(pre, post) + + +def test_memmerging_partial_overlap(): + """ + Two different copies from overlapping + source range + + [0 128] + [24 88] + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + %4 = mload 96 + %5 = mload 24 + %6 = mload 56 + mstore 1064, %3 + mstore 1096, %4 + mstore 1000, %1 + mstore 1032, %2 + mstore 2024, %5 + mstore 2056, %6 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 128 + mcopy 2024, 24, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_partial_different_effect(): + """ + Only partial merge possible + because of the generic memory + effect barier + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + %3 = mload 64 + mstore 1000, %1 + mstore 1032, %2 + dloadbytes 2000, 1000, 1000 ; BARRIER + mstore 1064, %3 + stop + """ + + post = """ + _global: + %3 = mload 64 + mcopy 1000, 0, 64 + dloadbytes 2000, 1000, 1000 + mstore 1064, %3 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerge_ok_interval_subset(): + """ + Test subintervals get subsumed by larger intervals + mstore(, mload()) + mcopy(, , 64) + => + mcopy(, , 64) + Because the first mload/mstore is contained in the mcopy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + mstore 100, %1 + mcopy 100, 0, 33 + stop + """ + + post = """ + _global: + mcopy 100, 0, 33 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_ok_overlap(): + """ + Test for with source overlap + which is ok to do + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 24 + %3 = mload 48 + mstore 1000, %1 + mstore 1024, %2 + mstore 1048, %3 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 80 + stop + """ + + _check_pre_post(pre, post) + + +def test_memmerging_mcopy(): + """ + Test that sequences of mcopy get merged (not just loads/stores) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 32 + mcopy 1032, 32, 32 + mcopy 1064, 64, 64 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 128 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_small(): + """ + Test that sequences of mcopies get merged, and that mcopy of 32 bytes + gets transformed to mload/mstore (saves 1 byte) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 16 + mcopy 1016, 16, 16 + stop + """ + + post = """ + _global: + %1 = mload 0 + mstore 1000, %1 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_weird_bisect(): + """ + Check that bisect_left finds the correct merge + copy(80, 100, 2) + copy(150, 60, 1) + copy(82, 102, 3) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 80, 100, 2 + mcopy 150, 60, 1 + mcopy 82, 102, 3 + stop + """ + + post = """ + _global: + mcopy 150, 60, 1 + mcopy 80, 100, 5 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_weird_bisect2(): + """ + Check that bisect_left finds the correct merge + copy(80, 50, 2) + copy(20, 100, 1) + copy(82, 52, 3) + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 80, 50, 2 + mcopy 20, 100, 1 + mcopy 82, 52, 3 + stop + """ + + post = """ + _global: + mcopy 20, 100, 1 + mcopy 80, 50, 5 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_allowed_overlapping(): + """ + Test merge of interleaved mload/mstore/mcopy works + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 32 + mcopy 1000, 32, 128 + %2 = mload 0 + mstore 2032, %1 + mstore 2000, %2 + stop + """ + + post = """ + _global: + mcopy 1000, 32, 128 + mcopy 2000, 0, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_allowed_overlapping2(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 64 + %1 = mload 1032 + mstore 2000, %1 + %2 = mload 1064 + mstore 2032, %2 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 64 + mcopy 2000, 1032, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_unused_mload(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 100 + %2 = mload 132 + mstore 64, %2 + + # does not interfere with the mload/mstore merging even though + # it cannot be removed + %3 = mload 32 + + mstore 32, %1 + return %3, %3 + """ + + post = """ + _global: + %3 = mload 32 + mcopy 32, 100, 64 + return %3, %3 + """ + + _check_pre_post(pre, post) + + +def test_memmerging_unused_mload1(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 100 + %2 = mload 132 + mstore 0, %1 + + # does not interfere with the mload/mstore merging even though + # it cannot be removed + %3 = mload 32 + + mstore 32, %2 + return %3, %3 + """ + + post = """ + _global: + %3 = mload 32 + mcopy 0, 100, 64 + return %3, %3 + """ + _check_pre_post(pre, post) + + +def test_memmerging_mload_read_after_write_hazard(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 100 + %2 = mload 132 + mstore 0, %1 + %3 = mload 32 + mstore 32, %2 + %4 = mload 64 + + ; BARRIER - the load is overriden by existing copy + mstore 1000, %3 + mstore 1032, %4 + stop + """ + + post = """ + _global: + %3 = mload 32 + mcopy 0, 100, 64 + %4 = mload 64 + mstore 1000, %3 + mstore 1032, %4 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_mcopy_read_after_write_hazard(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 32, 64 + mcopy 2000, 1000, 64 ; BARRIER + mcopy 1064, 96, 64 + stop + """ + _check_no_change(pre) + + +def test_memmerging_write_after_write(): + """ + Check that conflicting writes (from different source locations) + produce a barrier - mstore+mstore version + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 100 + %3 = mload 32 + %4 = mload 132 + mstore 1000, %1 + mstore 1000, %2 ; BARRIER + mstore 1032, %4 + mstore 1032, %3 ; BARRIER + """ + _check_no_change(pre) + + +def test_memmerging_write_after_write_mstore_and_mcopy(): + """ + Check that conflicting writes (from different source locations) + produce a barrier - mstore+mcopy version + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 132 + mstore 1000, %1 + mcopy 1000, 100, 16 ; write barrier + mstore 1032, %2 + mcopy 1016, 116, 64 + stop + """ + _check_no_change(pre) + + +def test_memmerging_write_after_write_only_mcopy(): + """ + Check that conflicting writes (from different source locations) + produce a barrier - mcopy+mcopy version + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 16 + mcopy 1000, 100, 16 ; write barrier + mcopy 1016, 116, 64 + mcopy 1016, 16, 64 + stop + """ + + post = """ + _global: + mcopy 1000, 0, 16 + mcopy 1000, 100, 80 + mcopy 1016, 16, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memmerging_not_allowed_overlapping(): + if not version_check(begin="cancun"): + return + + # NOTE: maybe optimization is possible here, to: + # mcopy 2000, 1000, 64 + # mcopy 1000, 0, 128 + pre = """ + _global: + %1 = mload 1000 + %2 = mload 1032 + mcopy 1000, 0, 128 + mstore 2000, %1 ; BARRIER - the mload and mcopy cannot be combined + mstore 2032, %2 + stop + """ + _check_no_change(pre) + + +def test_memmerging_not_allowed_overlapping2(): + if not version_check(begin="cancun"): + return + + # NOTE: maybe optimization is possible here, to: + # mcopy 2000, 1000, 64 + # mcopy 1000, 0, 128 + pre = """ + _global: + %1 = mload 1032 + mcopy 1000, 0, 64 + mstore 2000, %1 + %2 = mload 1064 + mstore 2032, %2 + stop + """ + + _check_no_change(pre) + + +def test_memmerging_existing_copy_overwrite(): + """ + Check that memmerge does not write over source of another copy + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 1000, 0, 64 + %1 = mload 2000 + + # barrier, write over source of existing copy + mstore 0, %1 + + mcopy 1064, 64, 64 + stop + """ + + _check_no_change(pre) + + +def test_memmerging_double_use(): + if not version_check(begin="cancun"): + return + + pre = """ + _global: + %1 = mload 0 + %2 = mload 32 + mstore 1000, %1 + mstore 1032, %2 + return %1 + """ + + post = """ + _global: + %1 = mload 0 + mcopy 1000, 0, 64 + return %1 + """ + + _check_pre_post(pre, post) + + +@pytest.mark.parametrize("load_opcode,copy_opcode", LOAD_COPY) +def test_memmerging_load(load_opcode, copy_opcode): + pre = f""" + _global: + %1 = {load_opcode} 0 + mstore 32, %1 + %2 = {load_opcode} 32 + mstore 64, %2 + stop + """ + + post = f""" + _global: + {copy_opcode} 32, 0, 64 + stop + """ + _check_pre_post(pre, post) + + +@pytest.mark.parametrize("load_opcode,copy_opcode", LOAD_COPY) +def test_memmerging_two_intervals_diff_offset(load_opcode, copy_opcode): + """ + Test different dloadbytes/calldatacopy sequences are separately merged + """ + pre = f""" + _global: + %1 = {load_opcode} 0 + mstore 0, %1 + {copy_opcode} 32, 32, 64 + %2 = {load_opcode} 0 + mstore 8, %2 + {copy_opcode} 40, 32, 64 + stop + """ + + post = f""" + _global: + {copy_opcode} 0, 0, 96 + {copy_opcode} 8, 0, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_1(): + """ + Test of basic memzeroing done with mstore only + """ + + pre = """ + _global: + mstore 32, 0 + mstore 64, 0 + mstore 96, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + calldatacopy 32, %1, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_2(): + """ + Test of basic memzeroing done with calldatacopy only + + sequence of these instruction will + zero out the memory at destination + %1 = calldatasize + calldatacopy %1 + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 64, %1, 128 + %2 = calldatasize + calldatacopy 192, %2, 128 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + %3 = calldatasize + calldatacopy 64, %3, 256 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_3(): + """ + Test of basic memzeroing done with combination of + mstores and calldatacopies + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 0, %1, 100 + mstore 100, 0 + %2 = calldatasize + calldatacopy 132, %2, 100 + mstore 232, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + %3 = calldatasize + calldatacopy 0, %3, 264 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_small_calldatacopy(): + """ + Test of converting calldatacopy of + size 32 into mstore + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 0, %1, 32 + stop + """ + + post = """ + _global: + %1 = calldatasize + mstore 0, 0 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_smaller_calldatacopy(): + """ + Test merging smaller (<32) calldatacopies + into either calldatacopy or mstore + """ + + pre = """ + _global: + %1 = calldatasize + calldatacopy 0, %1, 8 + %2 = calldatasize + calldatacopy 8, %2, 16 + %3 = calldatasize + calldatacopy 100, %3, 8 + %4 = calldatasize + calldatacopy 108, %4, 16 + %5 = calldatasize + calldatacopy 124, %5, 8 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + %6 = calldatasize + calldatacopy 0, %6, 24 + %3 = calldatasize + %4 = calldatasize + %5 = calldatasize + mstore 100, 0 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_overlap(): + """ + Test of merging overlaping zeroing intervals + + [128 160] + [136 192] + """ + + pre = """ + _global: + mstore 100, 0 + %1 = calldatasize + calldatacopy 108, %1, 56 + stop + """ + + post = """ + _global: + %1 = calldatasize + %2 = calldatasize + calldatacopy 100, %2, 64 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_imposs(): + """ + Test of memzeroing barriers caused + by non constant arguments + """ + + pre = """ + _global: + %1 = param ; abstract location, causes barrier + mstore 32, 0 + mstore %1, 0 + mstore 64, 0 + %2 = calldatasize + calldatacopy %1, %2, 10 + mstore 96, 0 + %3 = calldatasize + calldatacopy 10, %3, %1 + mstore 128, 0 + calldatacopy 10, %1, 10 + mstore 160, 0 + stop + """ + _check_no_change(pre) + + +def test_memzeroing_imposs_effect(): + """ + Test of memzeroing bariers caused + by different effect + """ + + pre = """ + _global: + mstore 32, 0 + dloadbytes 10, 20, 30 ; BARRIER + mstore 64, 0 + stop + """ + _check_no_change(pre) + + +def test_memzeroing_overlaping(): + """ + Test merging overlapping memzeroes (they can be merged + since both result in zeroes being written to destination) + """ + + pre = """ + _global: + mstore 32, 0 + mstore 96, 0 + mstore 32, 0 + mstore 64, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + calldatacopy 32, %1, 96 + stop + """ + _check_pre_post(pre, post) + + +def test_memzeroing_interleaved(): + """ + Test merging overlapping memzeroes (they can be merged + since both result in zeroes being written to destination) + """ + + pre = """ + _global: + mstore 32, 0 + mstore 1000, 0 + mstore 64, 0 + mstore 1032, 0 + stop + """ + + post = """ + _global: + %1 = calldatasize + calldatacopy 32, %1, 64 + %2 = calldatasize + calldatacopy 1000, %2, 64 + stop + """ + _check_pre_post(pre, post) diff --git a/tests/venom_utils.py b/tests/venom_utils.py index d4536e8bf7..85298ccb87 100644 --- a/tests/venom_utils.py +++ b/tests/venom_utils.py @@ -36,15 +36,11 @@ def assert_fn_eq(fn1: IRFunction, fn2: IRFunction): def assert_ctx_eq(ctx1: IRContext, ctx2: IRContext): - assert ctx1.last_label == ctx2.last_label - assert len(ctx1.functions) == len(ctx2.functions) for label1, fn1 in ctx1.functions.items(): assert label1 in ctx2.functions assert_fn_eq(fn1, ctx2.functions[label1]) + assert len(ctx1.functions) == len(ctx2.functions) # check entry function is the same assert next(iter(ctx1.functions.keys())) == next(iter(ctx2.functions.keys())) - - assert len(ctx1.data_segment) == len(ctx2.data_segment) - for d1, d2 in zip(ctx1.data_segment, ctx2.data_segment): - assert instructions_eq(d1, d2), f"data: [{d1}] != [{d2}]" + assert ctx1.data_segment == ctx2.data_segment, ctx2.data_segment diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 974685f403..ccc80947e4 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -638,7 +638,7 @@ class TopLevel(VyperNode): class Module(TopLevel): # metadata - __slots__ = ("path", "resolved_path", "source_id") + __slots__ = ("path", "resolved_path", "source_id", "is_interface") def to_dict(self): return dict(source_sha256sum=self.source_sha256sum, **super().to_dict()) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 783764271d..b00354c03a 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -71,6 +71,7 @@ class Module(TopLevel): path: str = ... resolved_path: str = ... source_id: int = ... + is_interface: bool = ... def namespace(self) -> Any: ... # context manager class FunctionDef(TopLevel): diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index d975aafac4..423b37721a 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -23,10 +23,11 @@ def parse_to_ast_with_settings( module_path: Optional[str] = None, resolved_path: Optional[str] = None, add_fn_node: Optional[str] = None, + is_interface: bool = False, ) -> tuple[Settings, vy_ast.Module]: try: return _parse_to_ast_with_settings( - vyper_source, source_id, module_path, resolved_path, add_fn_node + vyper_source, source_id, module_path, resolved_path, add_fn_node, is_interface ) except SyntaxException as e: e.resolved_path = resolved_path @@ -39,6 +40,7 @@ def _parse_to_ast_with_settings( module_path: Optional[str] = None, resolved_path: Optional[str] = None, add_fn_node: Optional[str] = None, + is_interface: bool = False, ) -> tuple[Settings, vy_ast.Module]: """ Parses a Vyper source string and generates basic Vyper AST nodes. @@ -62,6 +64,9 @@ def _parse_to_ast_with_settings( resolved_path: str, optional The resolved path of the source code Corresponds to FileInput.resolved_path + is_interface: bool + Indicates whether the source code should + be parsed as an interface file. Returns ------- @@ -106,6 +111,7 @@ def _parse_to_ast_with_settings( # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint + module.is_interface = is_interface return pre_parser.settings, module diff --git a/vyper/builtins/interfaces/IERC4626.vyi b/vyper/builtins/interfaces/IERC4626.vyi index 6d9e4c6ef7..0dd398d1f3 100644 --- a/vyper/builtins/interfaces/IERC4626.vyi +++ b/vyper/builtins/interfaces/IERC4626.vyi @@ -44,7 +44,7 @@ def previewDeposit(assets: uint256) -> uint256: ... @external -def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: +def deposit(assets: uint256, receiver: address) -> uint256: ... @view @@ -58,7 +58,7 @@ def previewMint(shares: uint256) -> uint256: ... @external -def mint(shares: uint256, receiver: address=msg.sender) -> uint256: +def mint(shares: uint256, receiver: address) -> uint256: ... @view @@ -72,7 +72,7 @@ def previewWithdraw(assets: uint256) -> uint256: ... @external -def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: +def withdraw(assets: uint256, receiver: address, owner: address) -> uint256: ... @view @@ -86,5 +86,5 @@ def previewRedeem(shares: uint256) -> uint256: ... @external -def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sender) -> uint256: +def redeem(shares: uint256, receiver: address, owner: address) -> uint256: ... diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 046cac2c0b..390416799a 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -359,7 +359,7 @@ def compile_files( # we allow this instead of requiring a different mode (like # `--zip`) so that verifier pipelines do not need a different # workflow for archive files and single-file contracts. - output = compile_from_zip(file_name, output_formats, settings, no_bytecode_metadata) + output = compile_from_zip(file_name, final_formats, settings, no_bytecode_metadata) ret[file_path] = output continue except NotZipInput: diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py index f7bcb622c7..5f632f4167 100755 --- a/vyper/cli/vyper_json.py +++ b/vyper/cli/vyper_json.py @@ -286,8 +286,17 @@ def get_settings(input_dict: dict) -> Settings: else: assert optimize is None + debug = input_dict["settings"].get("debug", None) + + # TODO: maybe change these to camelCase for consistency + enable_decimals = input_dict["settings"].get("enable_decimals", None) + return Settings( - evm_version=evm_version, optimize=optimize, experimental_codegen=experimental_codegen + evm_version=evm_version, + optimize=optimize, + experimental_codegen=experimental_codegen, + debug=debug, + enable_decimals=enable_decimals, ) diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index d885599cec..57bd2f4096 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -46,6 +46,13 @@ "opcodes_runtime": output.build_opcodes_runtime_output, } +INTERFACE_OUTPUT_FORMATS = [ + "ast_dict", + "annotated_ast_dict", + "interface", + "external_interface", + "abi", +] UNKNOWN_CONTRACT_NAME = "" @@ -121,10 +128,18 @@ def outputs_from_compiler_data( output_formats = ("bytecode",) ret = {} + with anchor_settings(compiler_data.settings): for output_format in output_formats: if output_format not in OUTPUT_FORMATS: raise ValueError(f"Unsupported format type {repr(output_format)}") + + is_vyi = compiler_data.file_input.resolved_path.suffix == ".vyi" + if is_vyi and output_format not in INTERFACE_OUTPUT_FORMATS: + raise ValueError( + f"Unsupported format for compiling interface: {repr(output_format)}" + ) + try: formatter = OUTPUT_FORMATS[output_format] ret[output_format] = formatter(compiler_data) diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 82a9602540..b6a0e8ac8c 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -108,9 +108,8 @@ def build_integrity(compiler_data: CompilerData) -> str: def build_external_interface_output(compiler_data: CompilerData) -> str: interface = compiler_data.annotated_vyper_module._metadata["type"].interface stem = PurePath(compiler_data.contract_path).stem - # capitalize words separated by '_' - # ex: test_interface.vy -> TestInterface - name = "".join([x.capitalize() for x in stem.split("_")]) + + name = stem.title().replace("_", "") out = f"\n# External Interfaces\ninterface {name}:\n" for func in interface.functions.values(): @@ -136,6 +135,14 @@ def build_interface_output(compiler_data: CompilerData) -> str: out += f" {member_name}: {member_type}\n" out += "\n\n" + if len(interface.flags) > 0: + out += "# Flags\n\n" + for flag in interface.flags.values(): + out += f"flag {flag.name}:\n" + for flag_value in flag._flag_members: + out += f" {flag_value}\n" + out += "\n\n" + if len(interface.events) > 0: out += "# Events\n\n" for event in interface.events.values(): @@ -282,7 +289,8 @@ def build_method_identifiers_output(compiler_data: CompilerData) -> dict: def build_abi_output(compiler_data: CompilerData) -> list: module_t = compiler_data.annotated_vyper_module._metadata["type"] - _ = compiler_data.ir_runtime # ensure _ir_info is generated + if not compiler_data.annotated_vyper_module.is_interface: + _ = compiler_data.ir_runtime # ensure _ir_info is generated abi = module_t.interface.to_toplevel_abi_dict() if module_t.init_function: diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 3d5791a644..17812ee535 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -13,6 +13,7 @@ from vyper.compiler.input_bundle import FileInput, FilesystemInputBundle, InputBundle from vyper.compiler.settings import OptimizationLevel, Settings, anchor_settings, merge_settings from vyper.ir import compile_ir, optimizer +from vyper.ir.compile_ir import reset_symbols from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.analysis.data_positions import generate_layout_export from vyper.semantics.analysis.imports import resolve_imports @@ -114,11 +115,14 @@ def contract_path(self): @cached_property def _generate_ast(self): + is_vyi = self.contract_path.suffix == ".vyi" + settings, ast = vy_ast.parse_to_ast_with_settings( self.source_code, self.source_id, module_path=self.contract_path.as_posix(), resolved_path=self.file_input.resolved_path.as_posix(), + is_interface=is_vyi, ) if self.original_settings: @@ -320,6 +324,7 @@ def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, """ # make IR output the same between runs codegen.reset_names() + reset_symbols() with anchor_settings(settings): ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py index a8e28c1ed1..e9840e8334 100644 --- a/vyper/compiler/settings.py +++ b/vyper/compiler/settings.py @@ -120,12 +120,12 @@ def _merge_one(lhs, rhs, helpstr): return lhs if rhs is None else rhs ret = Settings() - ret.evm_version = _merge_one(one.evm_version, two.evm_version, "evm version") - ret.optimize = _merge_one(one.optimize, two.optimize, "optimize") - ret.experimental_codegen = _merge_one( - one.experimental_codegen, two.experimental_codegen, "experimental codegen" - ) - ret.enable_decimals = _merge_one(one.enable_decimals, two.enable_decimals, "enable-decimals") + for field in dataclasses.fields(ret): + if field.name == "compiler_version": + continue + pretty_name = field.name.replace("_", "-") # e.g. evm_version -> evm-version + val = _merge_one(getattr(one, field.name), getattr(two, field.name), pretty_name) + setattr(ret, field.name, val) return ret diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index e87cf1b310..936e6d5d72 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -54,6 +54,11 @@ def mksymbol(name=""): return f"_sym_{name}{_next_symbol}" +def reset_symbols(): + global _next_symbol + _next_symbol = 0 + + def mkdebug(pc_debugger, ast_source): i = Instruction("DEBUG", ast_source) i.pc_debugger = pc_debugger diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 737f675b7c..534af4d633 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -53,7 +53,7 @@ def analyze_module(module_ast: vy_ast.Module) -> ModuleT: add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ - return _analyze_module_r(module_ast) + return _analyze_module_r(module_ast, module_ast.is_interface) def _analyze_module_r(module_ast: vy_ast.Module, is_interface: bool = False): diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 73fa4878c7..d01ab23299 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -77,6 +77,9 @@ def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": self._helper.get_member(key, node) return self + def __str__(self): + return f"{self.name}" + def __repr__(self): arg_types = ",".join(repr(a) for a in self._flag_members) return f"flag {self.name}({arg_types})" diff --git a/vyper/utils.py b/vyper/utils.py index d635c78383..db50626713 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -391,6 +391,11 @@ def evm_twos_complement(x: int) -> int: return ((2**256 - 1) ^ x) + 1 +def evm_not(val: int) -> int: + assert 0 <= val <= SizeLimits.MAX_UINT256, "Value out of bounds" + return SizeLimits.MAX_UINT256 ^ val + + # EVM div semantics as a python function def evm_div(x, y): if y == 0: @@ -519,20 +524,79 @@ def indent(text: str, indent_chars: Union[str, List[str]] = " ", level: int = 1) @contextlib.contextmanager -def timeit(msg): +def timeit(msg): # pragma: nocover start_time = time.perf_counter() yield end_time = time.perf_counter() total_time = end_time - start_time - print(f"{msg}: Took {total_time:.4f} seconds") + print(f"{msg}: Took {total_time:.4f} seconds", file=sys.stderr) + + +_CUMTIMES = None + + +def _dump_cumtime(): # pragma: nocover + global _CUMTIMES + for msg, total_time in _CUMTIMES.items(): + print(f"{msg}: Cumulative time {total_time:.4f} seconds", file=sys.stderr) @contextlib.contextmanager -def timer(msg): - t0 = time.time() +def cumtimeit(msg): # pragma: nocover + import atexit + from collections import defaultdict + + global _CUMTIMES + + if _CUMTIMES is None: + warnings.warn("timing code, disable me before pushing!", stacklevel=2) + _CUMTIMES = defaultdict(int) + atexit.register(_dump_cumtime) + + start_time = time.perf_counter() yield - t1 = time.time() - print(f"{msg} took {t1 - t0}s") + end_time = time.perf_counter() + total_time = end_time - start_time + _CUMTIMES[msg] += total_time + + +_PROF = None + + +def _dump_profile(): # pragma: nocover + global _PROF + + _PROF.disable() # don't profile dumping stats + _PROF.dump_stats("stats") + + from pstats import Stats + + stats = Stats("stats", stream=sys.stderr) + stats.sort_stats("time") + stats.print_stats() + + +@contextlib.contextmanager +def profileit(): # pragma: nocover + """ + Helper function for local dev use, is not intended to ever be run in + production build + """ + import atexit + from cProfile import Profile + + global _PROF + if _PROF is None: + warnings.warn("profiling code, disable me before pushing!", stacklevel=2) + _PROF = Profile() + _PROF.disable() + atexit.register(_dump_profile) + + try: + _PROF.enable() + yield + finally: + _PROF.disable() def annotate_source_code( diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 7d9404b9ef..ddd9065194 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -15,8 +15,12 @@ BranchOptimizationPass, DFTPass, FloatAllocas, + LoadElimination, + LowerDloadPass, MakeSSA, Mem2Var, + MemMergePass, + ReduceLiteralsCodesize, RemoveUnusedVariablesPass, SimplifyCFGPass, StoreElimination, @@ -52,11 +56,17 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: SimplifyCFGPass(ac, fn).run_pass() MakeSSA(ac, fn).run_pass() + StoreElimination(ac, fn).run_pass() Mem2Var(ac, fn).run_pass() MakeSSA(ac, fn).run_pass() SCCP(ac, fn).run_pass() + + LoadElimination(ac, fn).run_pass() StoreElimination(ac, fn).run_pass() + MemMergePass(ac, fn).run_pass() SimplifyCFGPass(ac, fn).run_pass() + + LowerDloadPass(ac, fn).run_pass() AlgebraicOptimizationPass(ac, fn).run_pass() # NOTE: MakeSSA is after algebraic optimization it currently produces # smaller code by adding some redundant phi nodes. This is not a @@ -69,6 +79,10 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: RemoveUnusedVariablesPass(ac, fn).run_pass() StoreExpansionPass(ac, fn).run_pass() + + if optimize == OptimizationLevel.CODESIZE: + ReduceLiteralsCodesize(ac, fn).run_pass() + DFTPass(ac, fn).run_pass() diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index cb2904f97f..4c75c67700 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,3 +1,5 @@ +import json +import re from typing import TYPE_CHECKING, Any, Iterator, Optional, Union import vyper.venom.effects as effects @@ -105,7 +107,7 @@ def __init__(self, line_no: int, src: str) -> None: def __repr__(self) -> str: src = self.src if self.src else "" - return f"\t# line {self.line_no}: {src}".expandtabs(20) + return f"\t; line {self.line_no}: {src}".expandtabs(20) class IROperand: @@ -115,13 +117,20 @@ class IROperand: """ value: Any + _hash: Optional[int] = None + + def __init__(self, value: Any) -> None: + self.value = value + self._hash = None @property def name(self) -> str: return self.value def __hash__(self) -> int: - return hash(self.value) + if self._hash is None: + self._hash = hash(self.value) + return self._hash def __eq__(self, other) -> bool: if not isinstance(other, type(self)): @@ -141,7 +150,7 @@ class IRLiteral(IROperand): def __init__(self, value: int) -> None: assert isinstance(value, int), "value must be an int" - self.value = value + super().__init__(value) class IRVariable(IROperand): @@ -149,27 +158,25 @@ class IRVariable(IROperand): IRVariable represents a variable in IR. A variable is a string that starts with a %. """ - value: str - - def __init__(self, value: str, version: Optional[str | int] = None) -> None: - assert isinstance(value, str) - assert ":" not in value, "Variable name cannot contain ':'" - if version: - assert isinstance(value, str) or isinstance(value, int), "value must be an str or int" - value = f"{value}:{version}" - if value[0] != "%": - value = f"%{value}" - self.value = value + _name: str + version: Optional[int] + + def __init__(self, name: str, version: int = 0) -> None: + assert isinstance(name, str) + # TODO: allow version to be None + assert isinstance(version, int) + if not name.startswith("%"): + name = f"%{name}" + self._name = name + self.version = version + value = name + if version > 0: + value = f"{name}:{version}" + super().__init__(value) @property def name(self) -> str: - return self.value.split(":")[0] - - @property - def version(self) -> int: - if ":" not in self.value: - return 0 - return int(self.value.split(":")[1]) + return self._name class IRLabel(IROperand): @@ -183,9 +190,18 @@ class IRLabel(IROperand): value: str def __init__(self, value: str, is_symbol: bool = False) -> None: - assert isinstance(value, str), "value must be an str" - self.value = value + assert isinstance(value, str), f"not a str: {value} ({type(value)})" + assert len(value) > 0 self.is_symbol = is_symbol + super().__init__(value) + + _IS_IDENTIFIER = re.compile("[0-9a-zA-Z_]*") + + def __repr__(self): + if self.__class__._IS_IDENTIFIER.fullmatch(self.value): + return self.value + + return json.dumps(self.value) # escape it class IRInstruction: @@ -200,7 +216,7 @@ class IRInstruction: opcode: str operands: list[IROperand] - output: Optional[IROperand] + output: Optional[IRVariable] # set of live variables at this instruction liveness: OrderedSet[IRVariable] parent: "IRBasicBlock" @@ -212,7 +228,7 @@ def __init__( self, opcode: str, operands: list[IROperand] | Iterator[IROperand], - output: Optional[IROperand] = None, + output: Optional[IRVariable] = None, ): assert isinstance(opcode, str), "opcode must be an str" assert isinstance(operands, list | Iterator), "operands must be a list" @@ -360,20 +376,6 @@ def get_ast_source(self) -> Optional[IRnode]: return inst.ast_source return self.parent.parent.ast_source - def str_short(self) -> str: - s = "" - if self.output: - s += f"{self.output} = " - opcode = f"{self.opcode} " if self.opcode != "store" else "" - s += opcode - operands = self.operands - if opcode not in ["jmp", "jnz", "invoke"]: - operands = list(reversed(operands)) - s += ", ".join( - [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands] - ) - return s - def __repr__(self) -> str: s = "" if self.output: @@ -381,14 +383,15 @@ def __repr__(self) -> str: opcode = f"{self.opcode} " if self.opcode != "store" else "" s += opcode operands = self.operands - if opcode not in ("jmp", "jnz", "invoke"): + if self.opcode == "invoke": + operands = [operands[0]] + list(reversed(operands[1:])) + elif self.opcode not in ("jmp", "jnz", "phi"): operands = reversed(operands) # type: ignore - s += ", ".join( - [(f"label %{op}" if isinstance(op, IRLabel) else str(op)) for op in operands] - ) + + s += ", ".join([(f"@{op}" if isinstance(op, IRLabel) else str(op)) for op in operands]) if self.annotation: - s += f" <{self.annotation}>" + s += f" ; {self.annotation}" return f"{s: <30}" @@ -446,6 +449,8 @@ def __init__(self, label: IRLabel, parent: "IRFunction") -> None: self.out_vars = OrderedSet() self.is_reachable = False + self._garbage_instructions: set[IRInstruction] = set() + def add_cfg_in(self, bb: "IRBasicBlock") -> None: self.cfg_in.add(bb) @@ -464,7 +469,7 @@ def remove_cfg_out(self, bb: "IRBasicBlock") -> None: self.cfg_out.remove(bb) def append_instruction( - self, opcode: str, *args: Union[IROperand, int], ret: IRVariable = None + self, opcode: str, *args: Union[IROperand, int], ret: Optional[IRVariable] = None ) -> Optional[IRVariable]: """ Append an instruction to the basic block @@ -520,13 +525,20 @@ def insert_instruction(self, instruction: IRInstruction, index: Optional[int] = instruction.error_msg = self.parent.error_msg self.instructions.insert(index, instruction) + def mark_for_removal(self, instruction: IRInstruction) -> None: + self._garbage_instructions.add(instruction) + + def clear_dead_instructions(self) -> None: + if len(self._garbage_instructions) > 0: + self.instructions = [ + inst for inst in self.instructions if inst not in self._garbage_instructions + ] + self._garbage_instructions.clear() + def remove_instruction(self, instruction: IRInstruction) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" self.instructions.remove(instruction) - def clear_instructions(self) -> None: - self.instructions = [] - @property def phi_instructions(self) -> Iterator[IRInstruction]: for inst in self.instructions: @@ -644,10 +656,12 @@ def copy(self): return bb def __repr__(self) -> str: - s = ( - f"{repr(self.label)}: IN={[bb.label for bb in self.cfg_in]}" - f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars}\n" - ) + s = f"{self.label}: ; IN={[bb.label for bb in self.cfg_in]}" + s += f" OUT={[bb.label for bb in self.cfg_out]} => {self.out_vars}\n" for instruction in self.instructions: - s += f" {str(instruction).strip()}\n" + s += f" {str(instruction).strip()}\n" + if len(self.instructions) > 30: + s += f" ; {self.label}\n" + if len(self.instructions) > 30 or self.parent.num_basic_blocks > 5: + s += f" ; ({self.parent.name})\n\n" return s diff --git a/vyper/venom/context.py b/vyper/venom/context.py index 0b0252d976..0c5cbc379c 100644 --- a/vyper/venom/context.py +++ b/vyper/venom/context.py @@ -1,14 +1,40 @@ +import textwrap +from dataclasses import dataclass, field from typing import Optional -from vyper.venom.basicblock import IRInstruction, IRLabel, IROperand +from vyper.venom.basicblock import IRLabel from vyper.venom.function import IRFunction +@dataclass +class DataItem: + data: IRLabel | bytes # can be raw data or bytes + + def __str__(self): + if isinstance(self.data, IRLabel): + return f"@{self.data}" + else: + assert isinstance(self.data, bytes) + return f'x"{self.data.hex()}"' + + +@dataclass +class DataSection: + label: IRLabel + data_items: list[DataItem] = field(default_factory=list) + + def __str__(self): + ret = [f"dbsection {self.label.value}:"] + for item in self.data_items: + ret.append(f" db {item}") + return "\n".join(ret) + + class IRContext: functions: dict[IRLabel, IRFunction] ctor_mem_size: Optional[int] immutables_len: Optional[int] - data_segment: list[IRInstruction] + data_segment: list[DataSection] last_label: int def __init__(self) -> None: @@ -47,11 +73,16 @@ def chain_basic_blocks(self) -> None: for fn in self.functions.values(): fn.chain_basic_blocks() - def append_data(self, opcode: str, args: list[IROperand]) -> None: + def append_data_section(self, name: IRLabel) -> None: + self.data_segment.append(DataSection(name)) + + def append_data_item(self, data: IRLabel | bytes) -> None: """ - Append data + Append data to current data section """ - self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + assert len(self.data_segment) > 0 + data_section = self.data_segment[-1] + data_section.data_items.append(DataItem(data)) def as_graph(self) -> str: s = ["digraph G {"] @@ -62,14 +93,15 @@ def as_graph(self) -> str: return "\n".join(s) def __repr__(self) -> str: - s = ["IRContext:"] + s = [] for fn in self.functions.values(): - s.append(fn.__repr__()) + s.append(IRFunction.__repr__(fn)) s.append("\n") if len(self.data_segment) > 0: - s.append("\nData segment:") - for inst in self.data_segment: - s.append(f"{inst}") + s.append("data readonly {") + for data_section in self.data_segment: + s.append(textwrap.indent(DataSection.__str__(data_section), " ")) + s.append("}") return "\n".join(s) diff --git a/vyper/venom/effects.py b/vyper/venom/effects.py index 97cffe2cb2..bbda481e14 100644 --- a/vyper/venom/effects.py +++ b/vyper/venom/effects.py @@ -44,6 +44,7 @@ def __iter__(self): "invoke": ALL, # could be smarter, look up the effects of the invoked function "log": LOG, "dloadbytes": MEMORY, + "dload": MEMORY, "returndatacopy": MEMORY, "calldatacopy": MEMORY, "codecopy": MEMORY, diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 2372f8ba52..f02da77fe3 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -1,3 +1,4 @@ +import textwrap from typing import Iterator, Optional from vyper.codegen.ir_node import IRnode @@ -41,7 +42,7 @@ def append_basic_block(self, bb: IRBasicBlock): Append basic block to function. """ assert isinstance(bb, IRBasicBlock), bb - assert bb.label.name not in self._basic_block_dict + assert bb.label.name not in self._basic_block_dict, bb.label self._basic_block_dict[bb.label.name] = bb def remove_basic_block(self, bb: IRBasicBlock): @@ -222,7 +223,10 @@ def _make_label(bb): return "\n".join(ret) def __repr__(self) -> str: - str = f"IRFunction: {self.name}\n" + ret = f"function {self.name} {{\n" for bb in self.get_basic_blocks(): - str += f"{bb}\n" - return str.strip() + bb_str = textwrap.indent(str(bb), " ") + ret += f"{bb_str}\n" + ret = ret.strip() + "\n}" + ret += f" ; close function {self.name}" + return ret diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 782309d841..f46457b77f 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -4,7 +4,6 @@ from vyper.codegen.ir_node import IRnode from vyper.evm.opcodes import get_opcodes -from vyper.utils import MemoryPositions from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -67,6 +66,8 @@ "mload", "iload", "istore", + "dload", + "dloadbytes", "sload", "sstore", "tload", @@ -366,17 +367,15 @@ def _convert_ir_bb(fn, ir, symbols): elif ir.value == "symbol": return IRLabel(ir.args[0].value, True) elif ir.value == "data": - label = IRLabel(ir.args[0].value) - ctx.append_data("dbname", [label]) + label = IRLabel(ir.args[0].value, True) + ctx.append_data_section(label) for c in ir.args[1:]: - if isinstance(c, int): - assert 0 <= c <= 255, "data with invalid size" - ctx.append_data("db", [c]) # type: ignore - elif isinstance(c.value, bytes): - ctx.append_data("db", [c.value]) # type: ignore + if isinstance(c.value, bytes): + ctx.append_data_item(c.value) elif isinstance(c, IRnode): data = _convert_ir_bb(fn, c, symbols) - ctx.append_data("db", [data]) # type: ignore + assert isinstance(data, IRLabel) # help mypy + ctx.append_data_item(data) elif ir.value == "label": label = IRLabel(ir.args[0].value, True) bb = fn.get_basic_block() @@ -403,22 +402,6 @@ def _convert_ir_bb(fn, ir, symbols): else: bb.append_instruction("jmp", label) - elif ir.value == "dload": - arg_0 = _convert_ir_bb(fn, ir.args[0], symbols) - bb = fn.get_basic_block() - src = bb.append_instruction("add", arg_0, IRLabel("code_end")) - - bb.append_instruction("dloadbytes", 32, src, MemoryPositions.FREE_VAR_SPACE) - return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) - - elif ir.value == "dloadbytes": - dst, src_offset, len_ = _convert_ir_bb_list(fn, ir.args, symbols) - - bb = fn.get_basic_block() - src = bb.append_instruction("add", src_offset, IRLabel("code_end")) - bb.append_instruction("dloadbytes", len_, src, dst) - return None - elif ir.value == "mstore": # some upstream code depends on reversed order of evaluation -- # to fix upstream. diff --git a/vyper/venom/parser.py b/vyper/venom/parser.py index d219f271b3..5ccc29b7a4 100644 --- a/vyper/venom/parser.py +++ b/vyper/venom/parser.py @@ -1,3 +1,5 @@ +import json + from lark import Lark, Transformer from vyper.venom.basicblock import ( @@ -8,31 +10,35 @@ IROperand, IRVariable, ) -from vyper.venom.context import IRContext +from vyper.venom.context import DataItem, DataSection, IRContext from vyper.venom.function import IRFunction -VENOM_PARSER = Lark( - """ +VENOM_GRAMMAR = """ %import common.CNAME %import common.DIGIT + %import common.HEXDIGIT %import common.LETTER %import common.WS %import common.INT + %import common.SIGNED_INT + %import common.ESCAPED_STRING # Allow multiple comment styles COMMENT: ";" /[^\\n]*/ | "//" /[^\\n]*/ | "#" /[^\\n]*/ - start: function* data_section? + start: function* data_segment? # TODO: consider making entry block implicit, e.g. # `"{" instruction+ block* "}"` - function: "function" NAME "{" block* "}" + function: "function" LABEL_IDENT "{" block* "}" - data_section: "[data]" instruction* + data_segment: "data" "readonly" "{" data_section* "}" + data_section: "dbsection" LABEL_IDENT ":" data_item+ + data_item: "db" (HEXSTR | LABEL) - block: NAME ":" statement* + block: LABEL_IDENT ":" "\\n" statement* - statement: instruction | assignment + statement: (instruction | assignment) "\\n" assignment: VAR_IDENT "=" expr expr: instruction | operand instruction: OPCODE operands_list? @@ -41,16 +47,24 @@ operand: VAR_IDENT | CONST | LABEL - CONST: INT + CONST: SIGNED_INT OPCODE: CNAME - VAR_IDENT: "%" NAME - LABEL: "@" NAME + VAR_IDENT: "%" (DIGIT|LETTER|"_"|":")+ + + # handy for identifier to be an escaped string sometimes + # (especially for machine-generated labels) + LABEL_IDENT: (NAME | ESCAPED_STRING) + LABEL: "@" LABEL_IDENT + + DOUBLE_QUOTE: "\\"" NAME: (DIGIT|LETTER|"_")+ + HEXSTR: "x" DOUBLE_QUOTE (HEXDIGIT|"_")+ DOUBLE_QUOTE %ignore WS %ignore COMMENT """ -) + +VENOM_PARSER = Lark(VENOM_GRAMMAR) def _set_last_var(fn: IRFunction): @@ -83,24 +97,37 @@ def _ensure_terminated(bb): # TODO: raise error if still not terminated. -class _DataSegment: - def __init__(self, instructions): - self.instructions = instructions +def _unescape(s: str): + """ + Unescape the escaped string. This is the inverse of `IRLabel.__repr__()`. + """ + if s.startswith('"'): + return json.loads(s) + return s + + +class _TypedItem: + def __init__(self, children): + self.children = children + + +class _DataSegment(_TypedItem): + pass class VenomTransformer(Transformer): def start(self, children) -> IRContext: ctx = IRContext() - data_section = [] - if isinstance(children[-1], _DataSegment): - data_section = children.pop().instructions + if len(children) > 0 and isinstance(children[-1], _DataSegment): + ctx.data_segment = children.pop().children + funcs = children for fn_name, blocks in funcs: fn = ctx.create_function(fn_name) fn._basic_block_dict.clear() for block_name, instructions in blocks: - bb = IRBasicBlock(IRLabel(block_name), fn) + bb = IRBasicBlock(IRLabel(block_name, True), fn) fn.append_basic_block(bb) for instruction in instructions: @@ -112,8 +139,6 @@ def start(self, children) -> IRContext: _set_last_var(fn) _set_last_label(ctx) - ctx.data_segment = data_section - return ctx def function(self, children) -> tuple[str, list[tuple[str, list[IRInstruction]]]]: @@ -123,9 +148,25 @@ def function(self, children) -> tuple[str, list[tuple[str, list[IRInstruction]]] def statement(self, children): return children[0] - def data_section(self, children): + def data_segment(self, children): return _DataSegment(children) + def data_section(self, children): + label = IRLabel(children[0], True) + data_items = children[1:] + assert all(isinstance(item, DataItem) for item in data_items) + return DataSection(label, data_items) + + def data_item(self, children): + item = children[0] + if isinstance(item, IRLabel): + return DataItem(item) + assert item.startswith('x"') + assert item.endswith('"') + item = item.removeprefix('x"').removesuffix('"') + item = item.replace("_", "") + return DataItem(bytes.fromhex(item)) + def block(self, children) -> tuple[str, list[IRInstruction]]: label, *instructions = children return label, instructions @@ -152,8 +193,12 @@ def instruction(self, children) -> IRInstruction: # reverse operands, venom internally represents top of stack # as rightmost operand - if opcode not in ("jmp", "jnz", "invoke", "phi"): - # special cases: operands with labels look better un-reversed + if opcode == "invoke": + # reverse stack arguments but not label arg + # invoke + operands = [operands[0]] + list(reversed(operands[1:])) + # special cases: operands with labels look better un-reversed + elif opcode not in ("jmp", "jnz", "phi"): operands.reverse() return IRInstruction(opcode, operands) @@ -166,17 +211,15 @@ def operand(self, children) -> IROperand: def OPCODE(self, token): return token.value + def LABEL_IDENT(self, label) -> str: + return _unescape(label) + def LABEL(self, label) -> IRLabel: - return IRLabel(label[1:]) + label = _unescape(label[1:]) + return IRLabel(label, True) def VAR_IDENT(self, var_ident) -> IRVariable: - parts = var_ident[1:].split(":", maxsplit=1) - assert 1 <= len(parts) <= 2 - varname = parts[0] - version = None - if len(parts) > 1: - version = int(parts[1]) - return IRVariable(varname, version=version) + return IRVariable(var_ident[1:]) def CONST(self, val) -> IRLiteral: return IRLiteral(int(val)) diff --git a/vyper/venom/passes/__init__.py b/vyper/venom/passes/__init__.py index fcd2aa1f22..a3227dcf4b 100644 --- a/vyper/venom/passes/__init__.py +++ b/vyper/venom/passes/__init__.py @@ -2,8 +2,12 @@ from .branch_optimization import BranchOptimizationPass from .dft import DFTPass from .float_allocas import FloatAllocas +from .literals_codesize import ReduceLiteralsCodesize +from .load_elimination import LoadElimination +from .lower_dload import LowerDloadPass from .make_ssa import MakeSSA from .mem2var import Mem2Var +from .memmerging import MemMergePass from .normalization import NormalizationPass from .remove_unused_variables import RemoveUnusedVariablesPass from .sccp import SCCP diff --git a/vyper/venom/passes/literals_codesize.py b/vyper/venom/passes/literals_codesize.py new file mode 100644 index 0000000000..daf195dfd4 --- /dev/null +++ b/vyper/venom/passes/literals_codesize.py @@ -0,0 +1,58 @@ +from vyper.utils import evm_not +from vyper.venom.basicblock import IRLiteral +from vyper.venom.passes.base_pass import IRPass + +# not takes 1 byte1, so it makes sense to use it when we can save at least +# 1 byte +NOT_THRESHOLD = 1 + +# shl takes 3 bytes, so it makes sense to use it when we can save at least +# 3 bytes +SHL_THRESHOLD = 3 + + +class ReduceLiteralsCodesize(IRPass): + def run_pass(self): + for bb in self.function.get_basic_blocks(): + self._process_bb(bb) + + def _process_bb(self, bb): + for inst in bb.instructions: + if inst.opcode != "store": + continue + + (op,) = inst.operands + if not isinstance(op, IRLiteral): + continue + + val = op.value % (2**256) + + # calculate amount of bits saved by not optimization + not_benefit = ((len(hex(val)) // 2 - len(hex(evm_not(val))) // 2) - NOT_THRESHOLD) * 8 + + # calculate amount of bits saved by shl optimization + binz = bin(val)[2:] + ix = len(binz) - binz.rfind("1") + shl_benefit = ix - SHL_THRESHOLD * 8 + + if not_benefit <= 0 and shl_benefit <= 0: + # no optimization can be done here + continue + + if not_benefit >= shl_benefit: + assert not_benefit > 0 # implied by previous conditions + # transform things like 0xffff...01 to (not 0xfe) + inst.opcode = "not" + op.value = evm_not(val) + continue + else: + assert shl_benefit > 0 # implied by previous conditions + # transform things like 0x123400....000 to 0x1234 << ... + ix -= 1 + # sanity check + assert (val >> ix) << ix == val, val + assert (val >> ix) & 1 == 1, val + + inst.opcode = "shl" + inst.operands = [IRLiteral(val >> ix), IRLiteral(ix)] + continue diff --git a/vyper/venom/passes/load_elimination.py b/vyper/venom/passes/load_elimination.py new file mode 100644 index 0000000000..6701b588fe --- /dev/null +++ b/vyper/venom/passes/load_elimination.py @@ -0,0 +1,50 @@ +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis, VarEquivalenceAnalysis +from vyper.venom.effects import Effects +from vyper.venom.passes.base_pass import IRPass + + +class LoadElimination(IRPass): + """ + Eliminate sloads, mloads and tloads + """ + + # should this be renamed to EffectsElimination? + + def run_pass(self): + self.equivalence = self.analyses_cache.request_analysis(VarEquivalenceAnalysis) + + for bb in self.function.get_basic_blocks(): + self._process_bb(bb, Effects.MEMORY, "mload", "mstore") + self._process_bb(bb, Effects.TRANSIENT, "tload", "tstore") + self._process_bb(bb, Effects.STORAGE, "sload", "sstore") + + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + self.analyses_cache.invalidate_analysis(DFGAnalysis) + + def equivalent(self, op1, op2): + return op1 == op2 or self.equivalence.equivalent(op1, op2) + + def _process_bb(self, bb, eff, load_opcode, store_opcode): + # not really a lattice even though it is not really inter-basic block; + # we may generalize in the future + lattice = () + + for inst in bb.instructions: + if eff in inst.get_write_effects(): + lattice = () + + if inst.opcode == store_opcode: + # mstore [val, ptr] + val, ptr = inst.operands + lattice = (ptr, val) + + if inst.opcode == load_opcode: + prev_lattice = lattice + (ptr,) = inst.operands + lattice = (ptr, inst.output) + if not prev_lattice: + continue + if not self.equivalent(ptr, prev_lattice[0]): + continue + inst.opcode = "store" + inst.operands = [prev_lattice[1]] diff --git a/vyper/venom/passes/lower_dload.py b/vyper/venom/passes/lower_dload.py new file mode 100644 index 0000000000..c863a1b7c7 --- /dev/null +++ b/vyper/venom/passes/lower_dload.py @@ -0,0 +1,42 @@ +from vyper.utils import MemoryPositions +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral +from vyper.venom.passes.base_pass import IRPass + + +class LowerDloadPass(IRPass): + """ + Lower dload and dloadbytes instructions + """ + + def run_pass(self): + for bb in self.function.get_basic_blocks(): + self._handle_bb(bb) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + self.analyses_cache.invalidate_analysis(DFGAnalysis) + + def _handle_bb(self, bb: IRBasicBlock): + fn = bb.parent + for idx, inst in enumerate(bb.instructions): + if inst.opcode == "dload": + (ptr,) = inst.operands + var = fn.get_next_variable() + bb.insert_instruction( + IRInstruction("add", [ptr, IRLabel("code_end")], output=var), index=idx + ) + idx += 1 + dst = IRLiteral(MemoryPositions.FREE_VAR_SPACE) + bb.insert_instruction( + IRInstruction("codecopy", [IRLiteral(32), var, dst]), index=idx + ) + + inst.opcode = "mload" + inst.operands = [dst] + elif inst.opcode == "dloadbytes": + _, src, _ = inst.operands + code_ptr = fn.get_next_variable() + bb.insert_instruction( + IRInstruction("add", [src, IRLabel("code_end")], output=code_ptr), index=idx + ) + inst.opcode = "codecopy" + inst.operands[1] = code_ptr diff --git a/vyper/venom/passes/make_ssa.py b/vyper/venom/passes/make_ssa.py index 56d3e1b7d3..ee013e0f1d 100644 --- a/vyper/venom/passes/make_ssa.py +++ b/vyper/venom/passes/make_ssa.py @@ -35,8 +35,8 @@ def _add_phi_nodes(self): Add phi nodes to the function. """ self._compute_defs() - work = {var: 0 for var in self.dom.dfs_walk} - has_already = {var: 0 for var in self.dom.dfs_walk} + work = {bb: 0 for bb in self.dom.dfs_walk} + has_already = {bb: 0 for bb in self.dom.dfs_walk} i = 0 # Iterate over all variables @@ -96,7 +96,6 @@ def _rename_vars(self, basic_block: IRBasicBlock): self.var_name_counters[v_name] = i + 1 inst.output = IRVariable(v_name, version=i) - # note - after previous line, inst.output.name != v_name outs.append(inst.output.name) for bb in basic_block.cfg_out: @@ -106,8 +105,9 @@ def _rename_vars(self, basic_block: IRBasicBlock): assert inst.output is not None, "Phi instruction without output" for i, op in enumerate(inst.operands): if op == basic_block.label: + var = inst.operands[i + 1] inst.operands[i + 1] = IRVariable( - inst.output.name, version=self.var_name_stacks[inst.output.name][-1] + var.name, version=self.var_name_stacks[var.name][-1] ) for bb in self.dom.dominated[basic_block]: diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py index f93924d449..9f985e2b0b 100644 --- a/vyper/venom/passes/mem2var.py +++ b/vyper/venom/passes/mem2var.py @@ -34,31 +34,30 @@ def _mk_varname(self, varname: str): def _process_alloca_var(self, dfg: DFGAnalysis, var: IRVariable): """ - Process alloca allocated variable. If it is only used by mstore/mload/return - instructions, it is promoted to a stack variable. Otherwise, it is left as is. + Process alloca allocated variable. If it is only used by + mstore/mload/return instructions, it is promoted to a stack variable. + Otherwise, it is left as is. """ uses = dfg.get_uses(var) - if all([inst.opcode == "mload" for inst in uses]): - return - elif all([inst.opcode == "mstore" for inst in uses]): + if not all([inst.opcode in ["mstore", "mload", "return"] for inst in uses]): return - elif all([inst.opcode in ["mstore", "mload", "return"] for inst in uses]): - var_name = self._mk_varname(var.name) - for inst in uses: - if inst.opcode == "mstore": - inst.opcode = "store" - inst.output = IRVariable(var_name) - inst.operands = [inst.operands[0]] - elif inst.opcode == "mload": - inst.opcode = "store" - inst.operands = [IRVariable(var_name)] - elif inst.opcode == "return": - bb = inst.parent - idx = len(bb.instructions) - 1 - assert inst == bb.instructions[idx] # sanity - bb.insert_instruction( - IRInstruction("mstore", [IRVariable(var_name), inst.operands[1]]), idx - ) + + var_name = self._mk_varname(var.name) + var = IRVariable(var_name) + for inst in uses: + if inst.opcode == "mstore": + inst.opcode = "store" + inst.output = var + inst.operands = [inst.operands[0]] + elif inst.opcode == "mload": + inst.opcode = "store" + inst.operands = [var] + elif inst.opcode == "return": + bb = inst.parent + idx = len(bb.instructions) - 1 + assert inst == bb.instructions[idx] # sanity + new_inst = IRInstruction("mstore", [var, inst.operands[1]]) + bb.insert_instruction(new_inst, idx) def _process_palloca_var(self, dfg: DFGAnalysis, palloca_inst: IRInstruction, var: IRVariable): """ @@ -70,16 +69,18 @@ def _process_palloca_var(self, dfg: DFGAnalysis, palloca_inst: IRInstruction, va return var_name = self._mk_varname(var.name) + var = IRVariable(var_name) + # some value given to us by the calling convention palloca_inst.opcode = "mload" palloca_inst.operands = [palloca_inst.operands[0]] - palloca_inst.output = IRVariable(var_name) + palloca_inst.output = var for inst in uses: if inst.opcode == "mstore": inst.opcode = "store" - inst.output = IRVariable(var_name) + inst.output = var inst.operands = [inst.operands[0]] elif inst.opcode == "mload": inst.opcode = "store" - inst.operands = [IRVariable(var_name)] + inst.operands = [var] diff --git a/vyper/venom/passes/memmerging.py b/vyper/venom/passes/memmerging.py new file mode 100644 index 0000000000..2e5ee46b84 --- /dev/null +++ b/vyper/venom/passes/memmerging.py @@ -0,0 +1,358 @@ +from bisect import bisect_left +from dataclasses import dataclass + +from vyper.evm.opcodes import version_check +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLiteral, IRVariable +from vyper.venom.effects import Effects +from vyper.venom.passes.base_pass import IRPass + + +@dataclass +class _Interval: + start: int + length: int + + @property + def end(self): + return self.start + self.length + + +@dataclass +class _Copy: + # abstract "copy" operation which contains a list of copy instructions + # and can fuse them into a single copy operation. + dst: int + src: int + length: int + insts: list[IRInstruction] + + @classmethod + def memzero(cls, dst, length, insts): + # factory method to simplify creation of memory zeroing operations + # (which are similar to Copy operations but src is always + # `calldatasize`). choose src=dst, so that can_merge returns True + # for overlapping memzeros. + return cls(dst, dst, length, insts) + + @property + def src_end(self) -> int: + return self.src + self.length + + @property + def dst_end(self) -> int: + return self.dst + self.length + + def src_interval(self) -> _Interval: + return _Interval(self.src, self.length) + + def dst_interval(self) -> _Interval: + return _Interval(self.dst, self.length) + + def overwrites_self_src(self) -> bool: + # return true if dst overlaps src. this is important for blocking + # mcopy batching in certain cases. + return self.overwrites(self.src_interval()) + + def overwrites(self, interval: _Interval) -> bool: + # return true if dst of self overwrites the interval + a = max(self.dst, interval.start) + b = min(self.dst_end, interval.end) + return a < b + + def can_merge(self, other: "_Copy"): + # both source and destination have to be offset by same amount, + # otherwise they do not represent the same copy. e.g. + # Copy(0, 64, 16) + # Copy(11, 74, 16) + if self.src - other.src != self.dst - other.dst: + return False + + # the copies must at least touch each other + if other.dst > self.dst_end: + return False + + return True + + def merge(self, other: "_Copy"): + # merge other into self. e.g. + # Copy(0, 64, 16); Copy(16, 80, 8) => Copy(0, 64, 24) + + assert self.dst <= other.dst, "bad bisect_left" + assert self.can_merge(other) + + new_length = max(self.dst_end, other.dst_end) - self.dst + self.length = new_length + self.insts.extend(other.insts) + + def __repr__(self) -> str: + return f"({self.src}, {self.src_end}, {self.length}, {self.dst}, {self.dst_end})" + + +class MemMergePass(IRPass): + dfg: DFGAnalysis + _copies: list[_Copy] + _loads: dict[IRVariable, int] + + def run_pass(self): + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) # type: ignore + + for bb in self.function.get_basic_blocks(): + self._handle_bb_memzero(bb) + self._handle_bb(bb, "calldataload", "calldatacopy", allow_dst_overlaps_src=True) + self._handle_bb(bb, "dload", "dloadbytes", allow_dst_overlaps_src=True) + + if version_check(begin="cancun"): + # mcopy is available + self._handle_bb(bb, "mload", "mcopy") + + self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + + def _optimize_copy(self, bb: IRBasicBlock, copy_opcode: str, load_opcode: str): + for copy in self._copies: + copy.insts.sort(key=bb.instructions.index) + + if copy_opcode == "mcopy": + assert not copy.overwrites_self_src() + + pin_inst = None + inst = copy.insts[-1] + if copy.length != 32 or load_opcode == "dload": + inst.output = None + inst.opcode = copy_opcode + inst.operands = [IRLiteral(copy.length), IRLiteral(copy.src), IRLiteral(copy.dst)] + elif inst.opcode == "mstore": + # we already have a load which is the val for this mstore; + # leave it in place. + var, _ = inst.operands + assert isinstance(var, IRVariable) # help mypy + pin_inst = self.dfg.get_producing_instruction(var) + assert pin_inst is not None # help mypy + + else: + # we are converting an mcopy into an mload+mstore (mload+mstore + # is 1 byte smaller than mcopy). + index = inst.parent.instructions.index(inst) + var = bb.parent.get_next_variable() + load = IRInstruction(load_opcode, [IRLiteral(copy.src)], output=var) + inst.parent.insert_instruction(load, index) + + inst.output = None + inst.opcode = "mstore" + inst.operands = [var, IRLiteral(copy.dst)] + + for inst in copy.insts[:-1]: + if inst.opcode == load_opcode: + if inst is pin_inst: + continue + + # if the load is used by any instructions besides the ones + # we are removing, we can't delete it. (in the future this + # may be handled by "remove unused effects" pass). + assert isinstance(inst.output, IRVariable) # help mypy + uses = self.dfg.get_uses(inst.output) + if not all(use in copy.insts for use in uses): + continue + + bb.mark_for_removal(inst) + + self._copies.clear() + self._loads.clear() + + def _write_after_write_hazard(self, new_copy: _Copy) -> bool: + for copy in self._copies: + # note, these are the same: + # - new_copy.overwrites(copy.dst_interval()) + # - copy.overwrites(new_copy.dst_interval()) + if new_copy.overwrites(copy.dst_interval()) and not ( + copy.can_merge(new_copy) or new_copy.can_merge(copy) + ): + return True + return False + + def _read_after_write_hazard(self, new_copy: _Copy) -> bool: + new_copies = self._copies + [new_copy] + + # new copy would overwrite memory that + # needs to be read to optimize copy + if any(new_copy.overwrites(copy.src_interval()) for copy in new_copies): + return True + + # existing copies would overwrite memory that the + # new copy would need + if self._overwrites(new_copy.src_interval()): + return True + + return False + + def _find_insertion_point(self, new_copy: _Copy): + return bisect_left(self._copies, new_copy.dst, key=lambda c: c.dst) + + def _add_copy(self, new_copy: _Copy): + index = self._find_insertion_point(new_copy) + self._copies.insert(index, new_copy) + + i = max(index - 1, 0) + while i < min(index + 1, len(self._copies) - 1): + if self._copies[i].can_merge(self._copies[i + 1]): + self._copies[i].merge(self._copies[i + 1]) + del self._copies[i + 1] + else: + i += 1 + + def _overwrites(self, read_interval: _Interval) -> bool: + # check if any of self._copies tramples the interval + + # could use bisect_left to optimize, but it's harder to reason about + return any(c.overwrites(read_interval) for c in self._copies) + + def _handle_bb( + self, + bb: IRBasicBlock, + load_opcode: str, + copy_opcode: str, + allow_dst_overlaps_src: bool = False, + ): + self._loads = {} + self._copies = [] + + def _barrier(): + self._optimize_copy(bb, copy_opcode, load_opcode) + + # copy in necessary because there is a possibility + # of insertion in optimizations + for inst in bb.instructions.copy(): + if inst.opcode == load_opcode: + src_op = inst.operands[0] + if not isinstance(src_op, IRLiteral): + _barrier() + continue + + read_interval = _Interval(src_op.value, 32) + + # we will read from this memory so we need to put barier + if not allow_dst_overlaps_src and self._overwrites(read_interval): + _barrier() + + assert inst.output is not None + self._loads[inst.output] = src_op.value + + elif inst.opcode == "mstore": + var, dst = inst.operands + + if not isinstance(var, IRVariable) or not isinstance(dst, IRLiteral): + _barrier() + continue + + if var not in self._loads: + _barrier() + continue + + src_ptr = self._loads[var] + load_inst = self.dfg.get_producing_instruction(var) + assert load_inst is not None # help mypy + n_copy = _Copy(dst.value, src_ptr, 32, [inst, load_inst]) + + if self._write_after_write_hazard(n_copy): + _barrier() + # no continue needed, we have not invalidated the loads dict + + # check if the new copy does not overwrites existing data + if not allow_dst_overlaps_src and self._read_after_write_hazard(n_copy): + _barrier() + # this continue is necessary because we have invalidated + # the _loads dict, so src_ptr is no longer valid. + continue + self._add_copy(n_copy) + + elif inst.opcode == copy_opcode: + if not all(isinstance(op, IRLiteral) for op in inst.operands): + _barrier() + continue + + length, src, dst = inst.operands + n_copy = _Copy(dst.value, src.value, length.value, [inst]) + + if self._write_after_write_hazard(n_copy): + _barrier() + # check if the new copy does not overwrites existing data + if not allow_dst_overlaps_src and self._read_after_write_hazard(n_copy): + _barrier() + self._add_copy(n_copy) + + elif _volatile_memory(inst): + _barrier() + + _barrier() + bb.clear_dead_instructions() + + # optimize memzeroing operations + def _optimize_memzero(self, bb: IRBasicBlock): + for copy in self._copies: + inst = copy.insts[-1] + if copy.length == 32: + inst.opcode = "mstore" + inst.operands = [IRLiteral(0), IRLiteral(copy.dst)] + else: + index = bb.instructions.index(inst) + calldatasize = bb.parent.get_next_variable() + bb.insert_instruction(IRInstruction("calldatasize", [], output=calldatasize), index) + + inst.output = None + inst.opcode = "calldatacopy" + inst.operands = [IRLiteral(copy.length), calldatasize, IRLiteral(copy.dst)] + + for inst in copy.insts[:-1]: + bb.mark_for_removal(inst) + + self._copies.clear() + self._loads.clear() + + def _handle_bb_memzero(self, bb: IRBasicBlock): + self._loads = {} + self._copies = [] + + def _barrier(): + self._optimize_memzero(bb) + + # copy in necessary because there is a possibility + # of insertion in optimizations + for inst in bb.instructions.copy(): + if inst.opcode == "mstore": + val = inst.operands[0] + dst = inst.operands[1] + is_zero_literal = isinstance(val, IRLiteral) and val.value == 0 + if not (isinstance(dst, IRLiteral) and is_zero_literal): + _barrier() + continue + n_copy = _Copy.memzero(dst.value, 32, [inst]) + assert not self._write_after_write_hazard(n_copy) + self._add_copy(n_copy) + elif inst.opcode == "calldatacopy": + length, var, dst = inst.operands + if not isinstance(var, IRVariable): + _barrier() + continue + if not isinstance(dst, IRLiteral) or not isinstance(length, IRLiteral): + _barrier() + continue + src_inst = self.dfg.get_producing_instruction(var) + assert src_inst is not None, f"bad variable {var}" + if src_inst.opcode != "calldatasize": + _barrier() + continue + n_copy = _Copy.memzero(dst.value, length.value, [inst]) + assert not self._write_after_write_hazard(n_copy) + self._add_copy(n_copy) + elif _volatile_memory(inst): + _barrier() + continue + + _barrier() + bb.clear_dead_instructions() + + +def _volatile_memory(inst): + inst_effects = inst.get_read_effects() | inst.get_write_effects() + return Effects.MEMORY in inst_effects or Effects.MSIZE in inst_effects diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 7ca242c74e..37ba1023c9 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -45,9 +45,10 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB inst.operands[i] = split_bb.label # Update the labels in the data segment - for inst in fn.ctx.data_segment: - if inst.opcode == "db" and inst.operands[0] == bb.label: - inst.operands[0] = split_bb.label + for data_section in fn.ctx.data_segment: + for item in data_section.data_items: + if item.data == bb.label: + item.data = split_bb.label return split_bb diff --git a/vyper/venom/passes/sccp/eval.py b/vyper/venom/passes/sccp/eval.py index b5786bb304..99f0ba70d9 100644 --- a/vyper/venom/passes/sccp/eval.py +++ b/vyper/venom/passes/sccp/eval.py @@ -5,6 +5,7 @@ SizeLimits, evm_div, evm_mod, + evm_not, evm_pow, signed_to_unsigned, unsigned_to_signed, @@ -95,11 +96,6 @@ def _evm_sar(shift_len: int, value: int) -> int: return value >> shift_len -def _evm_not(value: int) -> int: - assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds" - return SizeLimits.MAX_UINT256 ^ value - - ARITHMETIC_OPS: dict[str, Callable[[list[IROperand]], int]] = { "add": _wrap_binop(operator.add), "sub": _wrap_binop(operator.sub), @@ -122,7 +118,7 @@ def _evm_not(value: int) -> int: "or": _wrap_binop(operator.or_), "and": _wrap_binop(operator.and_), "xor": _wrap_binop(operator.xor), - "not": _wrap_unop(_evm_not), + "not": _wrap_unop(evm_not), "signextend": _wrap_binop(_evm_signextend), "iszero": _wrap_unop(_evm_iszero), "shr": _wrap_binop(_evm_shr), diff --git a/vyper/venom/passes/store_elimination.py b/vyper/venom/passes/store_elimination.py index 559205adc8..a4f217505b 100644 --- a/vyper/venom/passes/store_elimination.py +++ b/vyper/venom/passes/store_elimination.py @@ -1,4 +1,4 @@ -from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis from vyper.venom.basicblock import IRVariable from vyper.venom.passes.base_pass import IRPass @@ -9,38 +9,37 @@ class StoreElimination(IRPass): and removes the `store` instruction. """ + # TODO: consider renaming `store` instruction, since it is confusing + # with LoadElimination + def run_pass(self): - self.analyses_cache.request_analysis(CFGAnalysis) - dfg = self.analyses_cache.request_analysis(DFGAnalysis) + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) - for var, inst in dfg.outputs.items(): + for var, inst in self.dfg.outputs.items(): if inst.opcode != "store": continue - if not isinstance(inst.operands[0], IRVariable): - continue - if inst.operands[0].name in ["%ret_ofst", "%ret_size"]: - continue - if inst.output.name in ["%ret_ofst", "%ret_size"]: - continue - self._process_store(dfg, inst, var, inst.operands[0]) + self._process_store(inst, var, inst.operands[0]) self.analyses_cache.invalidate_analysis(LivenessAnalysis) self.analyses_cache.invalidate_analysis(DFGAnalysis) - def _process_store(self, dfg, inst, var: IRVariable, new_var: IRVariable): + def _process_store(self, inst, var: IRVariable, new_var: IRVariable): """ Process store instruction. If the variable is only used by a load instruction, forward the variable to the load instruction. """ - if any([inst.opcode == "phi" for inst in dfg.get_uses(new_var)]): + if any([inst.opcode == "phi" for inst in self.dfg.get_uses(new_var)]): return - uses = dfg.get_uses(var) + uses = self.dfg.get_uses(var) if any([inst.opcode == "phi" for inst in uses]): return - for use_inst in uses: + for use_inst in uses.copy(): for i, operand in enumerate(use_inst.operands): if operand == var: use_inst.operands[i] = new_var + self.dfg.add_use(new_var, use_inst) + self.dfg.remove_use(var, use_inst) + inst.parent.remove_instruction(inst) diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 9b52b842ba..048555a221 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -41,6 +41,7 @@ "calldatacopy", "mcopy", "calldataload", + "codecopy", "gas", "gasprice", "gaslimit", @@ -121,6 +122,11 @@ def apply_line_numbers(inst: IRInstruction, asm) -> list[str]: return ret # type: ignore +def _as_asm_symbol(label: IRLabel) -> str: + # Lower an IRLabel to an assembly symbol + return f"_sym_{label.value}" + + # TODO: "assembly" gets into the recursion due to how the original # IR was structured recursively in regards with the deploy instruction. # There, recursing into the deploy instruction was by design, and @@ -182,19 +188,19 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: asm.extend(_REVERT_POSTAMBLE) # Append data segment - data_segments: dict = dict() - for inst in ctx.data_segment: - if inst.opcode == "dbname": - label = inst.operands[0].value - data_segments[label] = [DataHeader(f"_sym_{label}")] - elif inst.opcode == "db": - data = inst.operands[0] + for data_section in ctx.data_segment: + label = data_section.label + asm_data_section: list[Any] = [] + asm_data_section.append(DataHeader(_as_asm_symbol(label))) + for item in data_section.data_items: + data = item.data if isinstance(data, IRLabel): - data_segments[label].append(f"_sym_{data.value}") + asm_data_section.append(_as_asm_symbol(data)) else: - data_segments[label].append(data) + assert isinstance(data, bytes) + asm_data_section.append(data) - asm.extend(list(data_segments.values())) + asm.append(asm_data_section) if no_optimize is False: optimize_assembly(top_asm) @@ -259,7 +265,7 @@ def _emit_input_operands( # invoke emits the actual instruction itself so we don't need # to emit it here but we need to add it to the stack map if inst.opcode != "invoke": - assembly.append(f"_sym_{op.value}") + assembly.append(_as_asm_symbol(op)) stack.push(op) continue @@ -293,7 +299,7 @@ def _generate_evm_for_basicblock_r( asm = [] # assembly entry point into the block - asm.append(f"_sym_{basicblock.label}") + asm.append(_as_asm_symbol(basicblock.label)) asm.append("JUMPDEST") if len(basicblock.cfg_in) == 1: @@ -408,7 +414,9 @@ def _generate_evm_for_instruction( return apply_line_numbers(inst, assembly) if opcode == "offset": - assembly.extend(["_OFST", f"_sym_{inst.operands[1].value}", inst.operands[0].value]) + ofst, label = inst.operands + assert isinstance(label, IRLabel) # help mypy + assembly.extend(["_OFST", _as_asm_symbol(label), ofst.value]) assert isinstance(inst.output, IROperand), "Offset must have output" stack.push(inst.output) return apply_line_numbers(inst, assembly) @@ -470,26 +478,26 @@ def _generate_evm_for_instruction( pass elif opcode == "store": pass - elif opcode == "dbname": - pass elif opcode in ["codecopy", "dloadbytes"]: assembly.append("CODECOPY") + elif opcode == "dbname": + pass elif opcode == "jnz": # jump if not zero - if_nonzero_label = inst.operands[1] - if_zero_label = inst.operands[2] - assembly.append(f"_sym_{if_nonzero_label.value}") + if_nonzero_label, if_zero_label = inst.get_label_operands() + assembly.append(_as_asm_symbol(if_nonzero_label)) assembly.append("JUMPI") # make sure the if_zero_label will be optimized out # assert if_zero_label == next(iter(inst.parent.cfg_out)).label - assembly.append(f"_sym_{if_zero_label.value}") + assembly.append(_as_asm_symbol(if_zero_label)) assembly.append("JUMP") elif opcode == "jmp": - assert isinstance(inst.operands[0], IRLabel) - assembly.append(f"_sym_{inst.operands[0].value}") + (target,) = inst.operands + assert isinstance(target, IRLabel) + assembly.append(_as_asm_symbol(target)) assembly.append("JUMP") elif opcode == "djmp": assert isinstance( @@ -504,7 +512,7 @@ def _generate_evm_for_instruction( assembly.extend( [ f"_sym_label_ret_{self.label_counter}", - f"_sym_{target.value}", + _as_asm_symbol(target), "JUMP", f"_sym_label_ret_{self.label_counter}", "JUMPDEST",