-
Notifications
You must be signed in to change notification settings - Fork 318
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyRTG] Wrapper around SSA values and label support
- Loading branch information
Showing
9 changed files
with
213 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .circt import ir | ||
|
||
|
||
class CodeGenRoot: | ||
""" | ||
This is the base class for classes that have to be visited by the RTG tool | ||
during codegen. | ||
""" | ||
|
||
def codegen(self): | ||
assert False, "must be implemented by the subclass" | ||
|
||
|
||
class Value: | ||
""" | ||
This class wraps around MLIR SSA values to provide a more Python native | ||
experience. Instead of having a value class that stores the type, classes | ||
deriving from this class represent specific types of values. Operations on | ||
those values can then be exposed as methods that can support more convenient | ||
bridging between Python values and MLIR values (e.g., accepting a Python | ||
integer and automatically building a ConstantOp in MLIR). | ||
""" | ||
|
||
def get_type(self) -> ir.Type: | ||
assert False, "must be implemented by subclass" | ||
|
||
def _get_ssa_value(self) -> ir.Value: | ||
assert False, "must be implemented by subclass" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from __future__ import annotations | ||
|
||
from .circt import ir | ||
from .core import Value | ||
from .rtg import rtg | ||
|
||
|
||
class Label(Value): | ||
""" | ||
Represents an ISA Assembly label. It can be declared and then passed around | ||
like every value. To place a label at a specific location in a sequence call | ||
'place'. It is the user's responsibility to place a label such that if the | ||
label is used by an instruction in the fully randomized test, there exists | ||
exactly one placement of the label to not end up with ambiguity or usage of | ||
an undeclared label. | ||
""" | ||
|
||
def __init__(self, value: ir.Value): | ||
self._value = value | ||
|
||
def declare(string: str) -> Label: | ||
""" | ||
Declares a label with a fixed name. Labels returned by different calls to | ||
this function but with the same arguments refer to the same label. | ||
""" | ||
|
||
return rtg.LabelDeclOp(string, []) | ||
|
||
def declare_unique(string: str) -> Label: | ||
""" | ||
Declares a unique label. This means, all usages of the value returned by this | ||
function will refer to the same label, but no other label declarations can | ||
conflict with this label, including labels returned by other calls to this | ||
function or fixed labels declared with 'declare_label'. | ||
""" | ||
|
||
return rtg.LabelUniqueDeclOp(string, []) | ||
|
||
def place( | ||
self, | ||
visibility: rtg.LabelVisibility = rtg.LabelVisibility.LOCAL) -> None: | ||
""" | ||
Places a declared label in a sequence or test. | ||
""" | ||
|
||
return rtg.LabelOp(rtg.LabelVisibilityAttr.get(visibility), self._value) | ||
|
||
def get_type(self) -> ir.Type: | ||
return rtg.LabelType.get() | ||
|
||
def _get_ssa_value(self) -> ir.Value: | ||
return self._value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .support import wrap_opviews_with_values | ||
from .circt.dialects import rtg | ||
|
||
wrap_opviews_with_values(rtg, rtg.__name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .circt import support, ir | ||
from .core import Value | ||
|
||
|
||
def _FromCirctValue(value: ir.Value) -> Value: | ||
type = support.type_to_pytype(value.type) | ||
from .rtg import rtg | ||
if isinstance(type, rtg.LabelType): | ||
from .labels import Label | ||
return Label(value) | ||
assert False, "Unsupported value" | ||
|
||
|
||
def wrap_opviews_with_values(dialect, module_name, excluded=[]): | ||
""" | ||
Wraps all of a dialect's OpView classes to have their create method return a | ||
Value instead of an OpView. | ||
""" | ||
|
||
import sys | ||
module = sys.modules[module_name] | ||
|
||
for attr in dir(dialect): | ||
cls = getattr(dialect, attr) | ||
|
||
if attr not in excluded and isinstance(cls, type) and issubclass( | ||
cls, ir.OpView): | ||
|
||
def specialize_create(cls): | ||
|
||
def create(*args, **kwargs): | ||
# If any of the arguments are 'pyrtg.Value', we need to convert them. | ||
def to_circt(arg): | ||
if isinstance(arg, (list, tuple)): | ||
return [to_circt(a) for a in arg] | ||
return arg | ||
|
||
args = [to_circt(arg) for arg in args] | ||
kwargs = {k: to_circt(v) for k, v in kwargs.items()} | ||
# Create the OpView. | ||
if hasattr(cls, "create"): | ||
created = cls.create(*args, **kwargs) | ||
else: | ||
created = cls(*args, **kwargs) | ||
if isinstance(created, support.NamedValueOpView): | ||
created = created.opview | ||
|
||
# Return the wrapped values, if any. | ||
converted_results = tuple( | ||
_FromCirctValue(res) for res in created.results) | ||
return converted_results[0] if len( | ||
converted_results) == 1 else created | ||
|
||
return create | ||
|
||
wrapped_class = specialize_create(cls) | ||
setattr(module, attr, wrapped_class) | ||
else: | ||
setattr(module, attr, cls) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,57 @@ | ||
# RUN: %rtgtool% %s --seed=0 --output-format=mlir | FileCheck %s --check-prefix=MLIR | ||
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED | ||
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm | FileCheck %s --input-file=%t --check-prefix=ASM | ||
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM | ||
|
||
from pyrtg import test | ||
from pyrtg import test, Label, rtg | ||
|
||
# MLIR: rtg.test @test0 | ||
# MLIR-LABEL: rtg.test @test0 | ||
# MLIR-NEXT: } | ||
|
||
# ELABORATED: rtg.test @test0 | ||
# ELABORATED-LABEL: rtg.test @test0 | ||
# ELABORATED-NEXT: } | ||
|
||
# ASM: Begin of test0 | ||
# ASM-LABEL: Begin of test0 | ||
# ASM: End of test0 | ||
|
||
|
||
@test | ||
def test0(): | ||
pass | ||
|
||
|
||
# MLIR-LABEL: rtg.test @test_labels | ||
# MLIR-NEXT: [[L0:%.+]] = rtg.label_decl "l0" | ||
# MLIR-NEXT: [[L1:%.+]] = rtg.label_unique_decl "l1" | ||
# MLIR-NEXT: [[L2:%.+]] = rtg.label_unique_decl "l1" | ||
# MLIR-NEXT: rtg.label global [[L0]] | ||
# MLIR-NEXT: rtg.label external [[L1]] | ||
# MLIR-NEXT: rtg.label local [[L2]] | ||
# MLIR-NEXT: } | ||
|
||
# ELABORATED-LABEL: rtg.test @test_labels | ||
# ELABORATED-NEXT: [[L0:%.+]] = rtg.label_decl "l0" | ||
# ELABORATED-NEXT: rtg.label global [[L0]] | ||
# ELABORATED-NEXT: [[L1:%.+]] = rtg.label_decl "l1_0" | ||
# ELABORATED-NEXT: rtg.label external [[L1]] | ||
# ELABORATED-NEXT: [[L2:%.+]] = rtg.label_decl "l1_1" | ||
# ELABORATED-NEXT: rtg.label local [[L2]] | ||
# ELABORATED-NEXT: } | ||
|
||
# ASM-LABEL: Begin of test_labels | ||
# ASM-EMPTY: | ||
# ASM-NEXT: .global l0 | ||
# ASM-NEXT: l0: | ||
# ASM-NEXT: .extern l1_0 | ||
# ASM-NEXT: l1_1: | ||
# ASM-EMPTY: | ||
# ASM: End of test_labels | ||
|
||
|
||
@test | ||
def test_labels(): | ||
l0 = Label.declare("l0") | ||
l1 = Label.declare_unique("l1") | ||
l2 = Label.declare_unique("l1") | ||
l0.place(rtg.LabelVisibility.GLOBAL) | ||
l1.place(rtg.LabelVisibility.EXTERNAL) | ||
l2.place() |