Skip to content

Commit

Permalink
Add a fix to broadcast_in_dim nvFuser's executor definition to get …
Browse files Browse the repository at this point in the history
…an inline shape for constants (#1336)
  • Loading branch information
kevinstephano authored Oct 31, 2024
1 parent 5365c2b commit 01f3a04
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,10 @@ def broadcast_in_dim(
a: TensorProxy, shape: list[int], broadcast_dimensions: list[int], *, fd: FusionDefinition, lc_to_nv_map: dict
) -> Any:
nva = getnv(a, fd, lc_to_nv_map)
nv_shape = getnv(shape, fd, lc_to_nv_map)
if any(map(lambda x: isinstance(x, NumberProxy), shape)):
nv_shape = getnv(shape, fd, lc_to_nv_map)
else:
nv_shape = shape

return fd.ops.broadcast_in_dim(nva, nv_shape, broadcast_dimensions)

Expand Down

0 comments on commit 01f3a04

Please sign in to comment.