diff --git a/frontends/PyRTG/src/CMakeLists.txt b/frontends/PyRTG/src/CMakeLists.txt index 3d082afc8402..5a5732b592b0 100644 --- a/frontends/PyRTG/src/CMakeLists.txt +++ b/frontends/PyRTG/src/CMakeLists.txt @@ -14,6 +14,10 @@ declare_mlir_python_sources(PyRTGSources ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" SOURCES pyrtg/__init__.py + pyrtg/core.py + pyrtg/labels.py + pyrtg/rtg.py + pyrtg/support.py pyrtg/tests.py rtgtool/rtgtool.py ) diff --git a/frontends/PyRTG/src/pyrtg/__init__.py b/frontends/PyRTG/src/pyrtg/__init__.py index 4c205c453dca..fe0485aa8627 100644 --- a/frontends/PyRTG/src/pyrtg/__init__.py +++ b/frontends/PyRTG/src/pyrtg/__init__.py @@ -2,6 +2,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from . import circt from . import tests +from . import core from .tests import test +from .labels import Label +from .rtg import rtg diff --git a/frontends/PyRTG/src/pyrtg/core.py b/frontends/PyRTG/src/pyrtg/core.py new file mode 100644 index 000000000000..bb63a507cef2 --- /dev/null +++ b/frontends/PyRTG/src/pyrtg/core.py @@ -0,0 +1,32 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .circt import ir + + +class CodeGenRoot: + """ + This is the base class for classes that have to be visited by the RTG tool + during codegen. + """ + + def codegen(self): + assert False, "must be implemented by the subclass" + + +class Value: + """ + This class wraps around MLIR SSA values to provide a more Python native + experience. Instead of having a value class that stores the type, classes + deriving from this class represent specific types of values. Operations on + those values can then be exposed as methods that can support more convenient + bridging between Python values and MLIR values (e.g., accepting a Python + integer and automatically building a ConstantOp in MLIR). + """ + + def get_type(self) -> ir.Type: + assert False, "must be implemented by subclass" + + def _get_ssa_value(self) -> ir.Value: + assert False, "must be implemented by subclass" diff --git a/frontends/PyRTG/src/pyrtg/labels.py b/frontends/PyRTG/src/pyrtg/labels.py new file mode 100644 index 000000000000..bb360810773b --- /dev/null +++ b/frontends/PyRTG/src/pyrtg/labels.py @@ -0,0 +1,56 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from __future__ import annotations + +from .circt import ir +from .core import Value +from .rtg import rtg + + +class Label(Value): + """ + Represents an ISA Assembly label. It can be declared and then passed around + like every value. To place a label at a specific location in a sequence call + 'place'. It is the user's responsibility to place a label such that if the + label is used by an instruction in the fully randomized test, there exists + exactly one placement of the label to not end up with ambiguity or usage of + an undeclared label. + """ + + def __init__(self, value: ir.Value): + self._value = value + + def declare(string: str) -> Label: + """ + Declares a label with a fixed name. Labels returned by different calls to + this function but with the same arguments refer to the same label. + """ + + return rtg.LabelDeclOp(string, []) + + def declare_unique(string: str) -> Label: + """ + Declares a unique label. This means, all usages of the value returned by this + function will refer to the same label, but no other label declarations can + conflict with this label, including labels returned by other calls to this + function or fixed labels declared with 'declare_label'. + """ + + return rtg.LabelUniqueDeclOp(string, []) + + def place( + self, + visibility: rtg.LabelVisibility = rtg.LabelVisibility.LOCAL) -> None: + """ + Places a declared label in a sequence or test. + """ + + return rtg.LabelOp(rtg.LabelVisibilityAttr.get(visibility), self._value) + + def get_type(self) -> ir.Type: + return rtg.LabelType.get() + + def _get_ssa_value(self) -> ir.Value: + return self._value diff --git a/frontends/PyRTG/src/pyrtg/rtg.py b/frontends/PyRTG/src/pyrtg/rtg.py new file mode 100644 index 000000000000..d14c16b311d1 --- /dev/null +++ b/frontends/PyRTG/src/pyrtg/rtg.py @@ -0,0 +1,8 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .support import wrap_opviews_with_values +from .circt.dialects import rtg + +wrap_opviews_with_values(rtg, rtg.__name__) diff --git a/frontends/PyRTG/src/pyrtg/support.py b/frontends/PyRTG/src/pyrtg/support.py new file mode 100644 index 000000000000..f752f981fcc6 --- /dev/null +++ b/frontends/PyRTG/src/pyrtg/support.py @@ -0,0 +1,63 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .circt import support, ir +from .core import Value + + +def _FromCirctValue(value: ir.Value) -> Value: + type = support.type_to_pytype(value.type) + from .rtg import rtg + if isinstance(type, rtg.LabelType): + from .labels import Label + return Label(value) + assert False, "Unsupported value" + + +def wrap_opviews_with_values(dialect, module_name, excluded=[]): + """ + Wraps all of a dialect's OpView classes to have their create method return a + Value instead of an OpView. + """ + + import sys + module = sys.modules[module_name] + + for attr in dir(dialect): + cls = getattr(dialect, attr) + + if attr not in excluded and isinstance(cls, type) and issubclass( + cls, ir.OpView): + + def specialize_create(cls): + + def create(*args, **kwargs): + # If any of the arguments are 'pyrtg.Value', we need to convert them. + def to_circt(arg): + if isinstance(arg, (list, tuple)): + return [to_circt(a) for a in arg] + return arg + + args = [to_circt(arg) for arg in args] + kwargs = {k: to_circt(v) for k, v in kwargs.items()} + # Create the OpView. + if hasattr(cls, "create"): + created = cls.create(*args, **kwargs) + else: + created = cls(*args, **kwargs) + if isinstance(created, support.NamedValueOpView): + created = created.opview + + # Return the wrapped values, if any. + converted_results = tuple( + _FromCirctValue(res) for res in created.results) + return converted_results[0] if len( + converted_results) == 1 else created + + return create + + wrapped_class = specialize_create(cls) + setattr(module, attr, wrapped_class) + else: + setattr(module, attr, cls) diff --git a/frontends/PyRTG/src/pyrtg/tests.py b/frontends/PyRTG/src/pyrtg/tests.py index 4743710c2aef..10315a3674c4 100644 --- a/frontends/PyRTG/src/pyrtg/tests.py +++ b/frontends/PyRTG/src/pyrtg/tests.py @@ -5,10 +5,11 @@ import inspect from .circt import ir -from .circt.dialects import rtg +from .core import CodeGenRoot +from .rtg import rtg -class Test: +class Test(CodeGenRoot): """ Represents an RTG Test. Stores the test function and location. """ diff --git a/frontends/PyRTG/src/rtgtool/rtgtool.py b/frontends/PyRTG/src/rtgtool/rtgtool.py index c7d9c5f5e3a3..4e35d95ae2e6 100644 --- a/frontends/PyRTG/src/rtgtool/rtgtool.py +++ b/frontends/PyRTG/src/rtgtool/rtgtool.py @@ -128,7 +128,7 @@ def frontend_codegen(args: argparse.Namespace) -> ir.Module: module = ir.Module.create() with ir.InsertionPoint(module.body): for _, obj in inspect.getmembers(file): - if isinstance(obj, pyrtg.tests.Test): + if isinstance(obj, pyrtg.core.CodeGenRoot): obj.codegen() return module diff --git a/frontends/PyRTG/test/basic.py b/frontends/PyRTG/test/basic.py index b575ddce23f1..ab145d847e95 100644 --- a/frontends/PyRTG/test/basic.py +++ b/frontends/PyRTG/test/basic.py @@ -1,19 +1,57 @@ # RUN: %rtgtool% %s --seed=0 --output-format=mlir | FileCheck %s --check-prefix=MLIR # RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED -# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm | FileCheck %s --input-file=%t --check-prefix=ASM +# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM -from pyrtg import test +from pyrtg import test, Label, rtg -# MLIR: rtg.test @test0 +# MLIR-LABEL: rtg.test @test0 # MLIR-NEXT: } -# ELABORATED: rtg.test @test0 +# ELABORATED-LABEL: rtg.test @test0 # ELABORATED-NEXT: } -# ASM: Begin of test0 +# ASM-LABEL: Begin of test0 # ASM: End of test0 @test def test0(): pass + + +# MLIR-LABEL: rtg.test @test_labels +# MLIR-NEXT: [[L0:%.+]] = rtg.label_decl "l0" +# MLIR-NEXT: [[L1:%.+]] = rtg.label_unique_decl "l1" +# MLIR-NEXT: [[L2:%.+]] = rtg.label_unique_decl "l1" +# MLIR-NEXT: rtg.label global [[L0]] +# MLIR-NEXT: rtg.label external [[L1]] +# MLIR-NEXT: rtg.label local [[L2]] +# MLIR-NEXT: } + +# ELABORATED-LABEL: rtg.test @test_labels +# ELABORATED-NEXT: [[L0:%.+]] = rtg.label_decl "l0" +# ELABORATED-NEXT: rtg.label global [[L0]] +# ELABORATED-NEXT: [[L1:%.+]] = rtg.label_decl "l1_0" +# ELABORATED-NEXT: rtg.label external [[L1]] +# ELABORATED-NEXT: [[L2:%.+]] = rtg.label_decl "l1_1" +# ELABORATED-NEXT: rtg.label local [[L2]] +# ELABORATED-NEXT: } + +# ASM-LABEL: Begin of test_labels +# ASM-EMPTY: +# ASM-NEXT: .global l0 +# ASM-NEXT: l0: +# ASM-NEXT: .extern l1_0 +# ASM-NEXT: l1_1: +# ASM-EMPTY: +# ASM: End of test_labels + + +@test +def test_labels(): + l0 = Label.declare("l0") + l1 = Label.declare_unique("l1") + l2 = Label.declare_unique("l1") + l0.place(rtg.LabelVisibility.GLOBAL) + l1.place(rtg.LabelVisibility.EXTERNAL) + l2.place()