diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index c85ad67cb2..a13691103f 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -1135,7 +1135,10 @@ def broadcast_in_dim( a: TensorProxy, shape: list[int], broadcast_dimensions: list[int], *, fd: FusionDefinition, lc_to_nv_map: dict ) -> Any: nva = getnv(a, fd, lc_to_nv_map) - nv_shape = getnv(shape, fd, lc_to_nv_map) + if any(map(lambda x: isinstance(x, NumberProxy), shape)): + nv_shape = getnv(shape, fd, lc_to_nv_map) + else: + nv_shape = shape return fd.ops.broadcast_in_dim(nva, nv_shape, broadcast_dimensions)