Skip to content

Commit

Permalink
fix transform
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 23, 2023
1 parent a73d241 commit 9219966
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 198 deletions.
74 changes: 6 additions & 68 deletions mlir/extras/runtime/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,79 +210,17 @@ def lower_to_vulkan(self, index_bitwidth=None):
def transform_dialect_erase_schedule(self):
return self.add_pass("test-transform-dialect-erase-schedule")

def transform_dialect_interpreter(
def transform_interpreter(
self,
bind_first_extra_to_ops=None,
bind_first_extra_to_params=None,
bind_first_extra_to_results_of_ops=None,
bind_second_extra_to_ops=None,
bind_second_extra_to_params=None,
bind_second_extra_to_results_of_ops=None,
debug_payload_root_tag=None,
debug_transform_root_tag=None,
enable_expensive_checks=None,
transform_file_name=None,
test_module_generation=None,
disable_expensive_checks=None,
entry_point=None,
):
if bind_first_extra_to_ops is not None and isinstance(
bind_first_extra_to_ops, (list, tuple)
):
bind_first_extra_to_ops = ",".join(map(str, bind_first_extra_to_ops))
if bind_first_extra_to_params is not None and isinstance(
bind_first_extra_to_params, (list, tuple)
):
bind_first_extra_to_params = ",".join(map(str, bind_first_extra_to_params))
if bind_first_extra_to_results_of_ops is not None and isinstance(
bind_first_extra_to_results_of_ops, (list, tuple)
):
bind_first_extra_to_results_of_ops = ",".join(
map(str, bind_first_extra_to_results_of_ops)
)
if bind_second_extra_to_ops is not None and isinstance(
bind_second_extra_to_ops, (list, tuple)
):
bind_second_extra_to_ops = ",".join(map(str, bind_second_extra_to_ops))
if bind_second_extra_to_params is not None and isinstance(
bind_second_extra_to_params, (list, tuple)
):
bind_second_extra_to_params = ",".join(
map(str, bind_second_extra_to_params)
)
if bind_second_extra_to_results_of_ops is not None and isinstance(
bind_second_extra_to_results_of_ops, (list, tuple)
):
bind_second_extra_to_results_of_ops = ",".join(
map(str, bind_second_extra_to_results_of_ops)
)
if debug_payload_root_tag is not None and isinstance(
debug_payload_root_tag, (list, tuple)
):
debug_payload_root_tag = ",".join(map(str, debug_payload_root_tag))
if debug_transform_root_tag is not None and isinstance(
debug_transform_root_tag, (list, tuple)
):
debug_transform_root_tag = ",".join(map(str, debug_transform_root_tag))
if transform_file_name is not None and isinstance(
transform_file_name, (list, tuple)
):
transform_file_name = ",".join(map(str, transform_file_name))
if test_module_generation is not None and isinstance(
test_module_generation, (list, tuple)
):
test_module_generation = ",".join(map(str, test_module_generation))
return self.add_pass(
"test-transform-dialect-interpreter",
bind_first_extra_to_ops=bind_first_extra_to_ops,
bind_first_extra_to_params=bind_first_extra_to_params,
bind_first_extra_to_results_of_ops=bind_first_extra_to_results_of_ops,
bind_second_extra_to_ops=bind_second_extra_to_ops,
bind_second_extra_to_params=bind_second_extra_to_params,
bind_second_extra_to_results_of_ops=bind_second_extra_to_results_of_ops,
"transform-interpreter",
debug_payload_root_tag=debug_payload_root_tag,
debug_transform_root_tag=debug_transform_root_tag,
enable_expensive_checks=enable_expensive_checks,
transform_file_name=transform_file_name,
test_module_generation=test_module_generation,
disable_expensive_checks=disable_expensive_checks,
entry_point=entry_point,
)

############################
Expand Down
Loading

0 comments on commit 9219966

Please sign in to comment.