Skip to content

Commit

Permalink
Handle torchex.copy_ eqally to prims.copy_ (#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
shino16 authored Aug 26, 2024
1 parent e8d2cbf commit 91ff7b7
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion thunder/executors/data_dependent_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from thunder.core.trace import TraceCtx
from thunder.core.symbol import BoundSymbol
from thunder.core.proxies import variableify, Proxy
import thunder.core.prims as prims
from thunder.core.prims import PrimIDs
from thunder.executors import torchex


# Represents a region and its parents (regions it consumes the output of) and
Expand Down Expand Up @@ -103,7 +105,7 @@ def __init__(self, trace: TraceCtx):
for copy_node in copy_nodes:
node.parents.add(copy_node)
copy_node.children.add(node)
elif bsym.sym.id is PrimIDs.COPY_:
elif bsym.sym in (prims.copy_, torchex.copy_):
copy_nodes.append(node)

for bsym_id, node in enumerate(bsym_id_to_node_map):
Expand Down

0 comments on commit 91ff7b7

Please sign in to comment.