Skip to content

Commit 91ff7b7

Browse files
authored
Handle torchex.copy_ eqally to prims.copy_ (#931)
1 parent e8d2cbf commit 91ff7b7

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

thunder/executors/data_dependent_partition.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from thunder.core.trace import TraceCtx
1010
from thunder.core.symbol import BoundSymbol
1111
from thunder.core.proxies import variableify, Proxy
12+
import thunder.core.prims as prims
1213
from thunder.core.prims import PrimIDs
14+
from thunder.executors import torchex
1315

1416

1517
# Represents a region and its parents (regions it consumes the output of) and
@@ -103,7 +105,7 @@ def __init__(self, trace: TraceCtx):
103105
for copy_node in copy_nodes:
104106
node.parents.add(copy_node)
105107
copy_node.children.add(node)
106-
elif bsym.sym.id is PrimIDs.COPY_:
108+
elif bsym.sym in (prims.copy_, torchex.copy_):
107109
copy_nodes.append(node)
108110

109111
for bsym_id, node in enumerate(bsym_id_to_node_map):

0 commit comments

Comments
 (0)