Skip to content

Commit

Permalink
Normalize standard input shapes for rocBLAS
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca committed Feb 18, 2025
1 parent 3fd27d3 commit 42755e9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ struct find_rocblas_gemm_pointwise : gemm_pointwise
auto c_ins_name = c_ins->get_operator().name();
// const-fold input if not standard shape since rocblas can't handle it
// Updated for a case where "standard" shape has out-of-sequence strides
- if((not s.standard() or s.normalize_standard() != s) and c_ins_name != "multibroadcast")
+ if(not s.standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
Expand Down
11 changes: 8 additions & 3 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ template <typename T>
struct gemm_impl
{
gemm_impl(const shape& output_shape,
- const std::vector<shape>& input_shapes,
+ std::vector<shape> input_shapes,
T alpha_param,
T beta_param,
bool compute_fp32_flag)
Expand All @@ -182,6 +182,11 @@ struct gemm_impl
is_3inputs(input_shapes.size() == 4),
compute_fp32(compute_fp32_flag)
{
+ std::transform(input_shapes.begin(),
+ input_shapes.end(),
+ input_shapes.begin(),
+ [&](const shape& s) { return s.normalize_standard(); });

if(not is_3inputs)
{
beta = 0;
Expand Down Expand Up @@ -592,7 +597,7 @@ void gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
- [](const argument& x) { return x.get_shape().normalize_standard(); });
+ [](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
Expand All @@ -609,7 +614,7 @@ void gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
- [](const argument& x) { return x.get_shape().normalize_standard(); });
+ [](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
Expand Down

0 comments on commit 42755e9

Please sign in to comment.