Skip to content

Commit

Permalink
Don't const-fold for multibroadcast
Browse files: browse the repository at this point in the history
  • Loading branch information
ahsan-ca committed Feb 14, 2025
1 parent bbc964e commit 3fd27d3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/targets/gpu/fuse_ops.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -673,11 +673,12 @@ struct find_rocblas_gemm_pointwise : gemm_pointwise

if(ins->inputs().size() == 3)
{
auto c_ins = r.instructions["c"];
shape s = c_ins->get_shape();
auto c_ins = r.instructions["c"];
shape s = c_ins->get_shape();
auto c_ins_name = c_ins->get_operator().name();
// const-fold input if not standard shape since rocblas can't handle it
// Updated for a case where "standard" shape has out-of-sequence strides
if(not s.standard() or s.normalize_standard() != s)
if((not s.standard() or s.normalize_standard() != s) and c_ins_name != "multibroadcast")
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
Expand Down

0 comments on commit 3fd27d3

Please sign in to comment.