Skip to content

Commit

Permalink
[PyRTG] Wrapper around SSA values and label support
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Feb 19, 2025
1 parent 435dcee commit ff9c4c2
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 9 deletions.
4 changes: 4 additions & 0 deletions frontends/PyRTG/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ declare_mlir_python_sources(PyRTGSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
SOURCES
pyrtg/__init__.py
pyrtg/core.py
pyrtg/labels.py
pyrtg/rtg.py
pyrtg/support.py
pyrtg/tests.py
rtgtool/rtgtool.py
)
Expand Down
4 changes: 3 additions & 1 deletion frontends/PyRTG/src/pyrtg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import circt
from . import tests
from . import core
from .tests import test
from .labels import Label
from .rtg import rtg
32 changes: 32 additions & 0 deletions frontends/PyRTG/src/pyrtg/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .circt import ir


class CodeGenRoot:
"""
This is the base class for classes that have to be visited by the RTG tool
during codegen.
"""

def codegen(self):
assert False, "must be implemented by the subclass"


class Value:
"""
This class wraps around MLIR SSA values to provide a more Python native
experience. Instead of having a value class that stores the type, classes
deriving from this class represent specific types of values. Operations on
those values can then be exposed as methods that can support more convenient
bridging between Python values and MLIR values (e.g., accepting a Python
integer and automatically building a ConstantOp in MLIR).
"""

def get_type(self) -> ir.Type:
assert False, "must be implemented by subclass"

def _get_ssa_value(self) -> ir.Value:
assert False, "must be implemented by subclass"
56 changes: 56 additions & 0 deletions frontends/PyRTG/src/pyrtg/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations

from .circt import ir
from .core import Value
from .rtg import rtg


class Label(Value):
"""
Represents an ISA Assembly label. It can be declared and then passed around
like every value. To place a label at a specific location in a sequence call
'place'. It is the user's responsibility to place a label such that if the
label is used by an instruction in the fully randomized test, there exists
exactly one placement of the label to not end up with ambiguity or usage of
an undeclared label.
"""

def __init__(self, value: ir.Value):
self._value = value

def declare(string: str) -> Label:
"""
Declares a label with a fixed name. Labels returned by different calls to
this function but with the same arguments refer to the same label.
"""

return rtg.LabelDeclOp(string, [])

def declare_unique(string: str) -> Label:
"""
Declares a unique label. This means, all usages of the value returned by this
function will refer to the same label, but no other label declarations can
conflict with this label, including labels returned by other calls to this
function or fixed labels declared with 'declare_label'.
"""

return rtg.LabelUniqueDeclOp(string, [])

def place(
self,
visibility: rtg.LabelVisibility = rtg.LabelVisibility.LOCAL) -> None:
"""
Places a declared label in a sequence or test.
"""

return rtg.LabelOp(rtg.LabelVisibilityAttr.get(visibility), self._value)

def get_type(self) -> ir.Type:
return rtg.LabelType.get()

def _get_ssa_value(self) -> ir.Value:
return self._value
8 changes: 8 additions & 0 deletions frontends/PyRTG/src/pyrtg/rtg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .support import wrap_opviews_with_values
from .circt.dialects import rtg

wrap_opviews_with_values(rtg, rtg.__name__)
63 changes: 63 additions & 0 deletions frontends/PyRTG/src/pyrtg/support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .circt import support, ir
from .core import Value


def _FromCirctValue(value: ir.Value) -> Value:
type = support.type_to_pytype(value.type)
from .rtg import rtg
if isinstance(type, rtg.LabelType):
from .labels import Label
return Label(value)
assert False, "Unsupported value"


def wrap_opviews_with_values(dialect, module_name, excluded=[]):
"""
Wraps all of a dialect's OpView classes to have their create method return a
Value instead of an OpView.
"""

import sys
module = sys.modules[module_name]

for attr in dir(dialect):
cls = getattr(dialect, attr)

if attr not in excluded and isinstance(cls, type) and issubclass(
cls, ir.OpView):

def specialize_create(cls):

def create(*args, **kwargs):
# If any of the arguments are 'pyrtg.Value', we need to convert them.
def to_circt(arg):
if isinstance(arg, (list, tuple)):
return [to_circt(a) for a in arg]
return arg

args = [to_circt(arg) for arg in args]
kwargs = {k: to_circt(v) for k, v in kwargs.items()}
# Create the OpView.
if hasattr(cls, "create"):
created = cls.create(*args, **kwargs)
else:
created = cls(*args, **kwargs)
if isinstance(created, support.NamedValueOpView):
created = created.opview

# Return the wrapped values, if any.
converted_results = tuple(
_FromCirctValue(res) for res in created.results)
return converted_results[0] if len(
converted_results) == 1 else created

return create

wrapped_class = specialize_create(cls)
setattr(module, attr, wrapped_class)
else:
setattr(module, attr, cls)
5 changes: 3 additions & 2 deletions frontends/PyRTG/src/pyrtg/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import inspect

from .circt import ir
from .circt.dialects import rtg
from .core import CodeGenRoot
from .rtg import rtg


class Test:
class Test(CodeGenRoot):
"""
Represents an RTG Test. Stores the test function and location.
"""
Expand Down
2 changes: 1 addition & 1 deletion frontends/PyRTG/src/rtgtool/rtgtool.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def frontend_codegen(args: argparse.Namespace) -> ir.Module:
module = ir.Module.create()
with ir.InsertionPoint(module.body):
for _, obj in inspect.getmembers(file):
if isinstance(obj, pyrtg.tests.Test):
if isinstance(obj, pyrtg.core.CodeGenRoot):
obj.codegen()
return module

Expand Down
48 changes: 43 additions & 5 deletions frontends/PyRTG/test/basic.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,57 @@
# RUN: %rtgtool% %s --seed=0 --output-format=mlir | FileCheck %s --check-prefix=MLIR
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm | FileCheck %s --input-file=%t --check-prefix=ASM
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM

from pyrtg import test
from pyrtg import test, Label, rtg

# MLIR: rtg.test @test0
# MLIR-LABEL: rtg.test @test0
# MLIR-NEXT: }

# ELABORATED: rtg.test @test0
# ELABORATED-LABEL: rtg.test @test0
# ELABORATED-NEXT: }

# ASM: Begin of test0
# ASM-LABEL: Begin of test0
# ASM: End of test0


@test
def test0():
pass


# MLIR-LABEL: rtg.test @test_labels
# MLIR-NEXT: [[L0:%.+]] = rtg.label_decl "l0"
# MLIR-NEXT: [[L1:%.+]] = rtg.label_unique_decl "l1"
# MLIR-NEXT: [[L2:%.+]] = rtg.label_unique_decl "l1"
# MLIR-NEXT: rtg.label global [[L0]]
# MLIR-NEXT: rtg.label external [[L1]]
# MLIR-NEXT: rtg.label local [[L2]]
# MLIR-NEXT: }

# ELABORATED-LABEL: rtg.test @test_labels
# ELABORATED-NEXT: [[L0:%.+]] = rtg.label_decl "l0"
# ELABORATED-NEXT: rtg.label global [[L0]]
# ELABORATED-NEXT: [[L1:%.+]] = rtg.label_decl "l1_0"
# ELABORATED-NEXT: rtg.label external [[L1]]
# ELABORATED-NEXT: [[L2:%.+]] = rtg.label_decl "l1_1"
# ELABORATED-NEXT: rtg.label local [[L2]]
# ELABORATED-NEXT: }

# ASM-LABEL: Begin of test_labels
# ASM-EMPTY:
# ASM-NEXT: .global l0
# ASM-NEXT: l0:
# ASM-NEXT: .extern l1_0
# ASM-NEXT: l1_1:
# ASM-EMPTY:
# ASM: End of test_labels


@test
def test_labels():
l0 = Label.declare("l0")
l1 = Label.declare_unique("l1")
l2 = Label.declare_unique("l1")
l0.place(rtg.LabelVisibility.GLOBAL)
l1.place(rtg.LabelVisibility.EXTERNAL)
l2.place()

0 comments on commit ff9c4c2

Please sign in to comment.