Skip to content

Commit

Permalink
Do not reorder between copy_ and return statement (#920)
Browse files Browse the repository at this point in the history
  • Loading branch information
shino16 authored Aug 5, 2024
1 parent e05785e commit 681aef7
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
11 changes: 7 additions & 4 deletions thunder/executors/data_dependent_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self, trace: TraceCtx):
# as it appears to be far off from being universal.
# We use indices as hash values instead.
bsym_id_to_node_map: list[int] = []
copy_nodes: list[Node] = []
for bsym_id, bsym in enumerate(trace.bound_symbols):
node = Node(bsym_id, [bsym], [bsym_id], bsym_id, bsym_id)
bsym_id_to_node_map.append(node)
Expand All @@ -99,10 +100,13 @@ def __init__(self, trace: TraceCtx):
lambda: f"Found multiple RETURN nodes while converting a list of bound symbols to a dag",
)
self.return_node = node
for copy_node in copy_nodes:
node.parents.add(copy_node)
copy_node.children.add(node)
elif bsym.sym.id is PrimIDs.COPY_:
copy_nodes.append(node)

for bsym_id, node in enumerate(bsym_id_to_node_map):
has_parents: bool = False

bsym = node.group_bsyms[0]
for inp in bsym.flat_args:
if not isinstance(inp, Proxy):
Expand All @@ -111,9 +115,8 @@ def __init__(self, trace: TraceCtx):
producer_id = producers[inp]
parent = bsym_id_to_node_map[producer_id]
node.parents.add(parent)
has_parents = True

if not has_parents:
if not node.parents:
self.roots.append(node)

for out in bsym.flat_outs:
Expand Down
13 changes: 0 additions & 13 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,19 +818,6 @@ def _can_fuse_node(n: Node):
fused_bsyms.extend(fusion.bound_symbols)
fused_bsyms.extend(epilogue)

# Force return operator to be the last one in the fused_bsyms
if fused_bsyms[-1].sym.id != PrimIDs.RETURN:
return_idx: int = -1
for i, fused_bsym in enumerate(fused_bsyms):
if fused_bsym.sym.id == PrimIDs.RETURN:
return_idx = i
break
utils.check(
return_idx != -1,
lambda: f"Return operator does not exist in bound symbols",
)
fused_bsyms.append(fused_bsyms.pop(return_idx))

fusedtrace.bound_symbols = fused_bsyms

# Some of the operations might be better placed with its consumers (for
Expand Down
11 changes: 10 additions & 1 deletion thunder/tests/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def available_devicetypes():


class TestExecutor:
def is_available(self) -> bool:
return True

def supports_dtype(self, dtype: datatypes.dtype) -> bool:
return dtype in datatypes.resolve_dtypes(self.supported_dtypes)

Expand Down Expand Up @@ -209,6 +212,9 @@ class TorchCompileCatTestExecutor(TestExecutor):
supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
supported_dtypes = (datatypes.dtype,)

def is_available(self) -> bool:
return not IS_WINDOWS

def executors_list(self) -> list[extend.Executor]:
from thunder.executors.torch_compile import torch_compile_cat_ex

Expand All @@ -223,6 +229,9 @@ class TorchCompileTestExecutor(TestExecutor):
supported_devicetypes = (devices.DeviceType.CPU, devices.DeviceType.CUDA)
supported_dtypes = (datatypes.dtype,)

def is_available(self) -> bool:
return not IS_WINDOWS

def executors_list(self) -> list[extend.Executor]:
from thunder.executors.torch_compile import torch_compile_ex

Expand Down Expand Up @@ -465,7 +474,7 @@ def __call__(self, test_template):
for executor, devicetype in product(
sorted(self.executors, key=lambda x: repr(x)), sorted(self.devicetypes, key=lambda x: repr(x))
):
if executor is None:
if executor is None or not executor.is_available():
continue

if not executor.supports_devicetype(devicetype):
Expand Down
11 changes: 10 additions & 1 deletion thunder/tests/test_inplace_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import thunder.core.devices as devices
from thunder.core import dtypes
from thunder.core.prims import PrimIDs
from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING, TorchExecutor, nvFuserExecutor
from thunder.tests.framework import (
instantiate,
ops,
requiresCUDA,
NOTHING,
TorchExecutor,
TorchCompileExecutor,
nvFuserExecutor,
)
from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput
from thunder.tests.make_tensor import make_tensor
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
Expand Down Expand Up @@ -451,6 +459,7 @@ def f(xs, ys, z):

@instantiate(
dtypes=NOTHING,
executors=(TorchExecutor, TorchCompileExecutor, nvFuserExecutor),
)
def test_single_tensor_adam_like(executor, device, _):

Expand Down

0 comments on commit 681aef7

Please sign in to comment.