Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,69 @@ struct find_unpack_int4_mlir_op
}
};

/// Folds a trailing reshaper op into the MLIR submodule that produces its input.
///
/// Matches any op named in `reshaper_names()` (with "slice" removed) whose
/// first argument is a `gpu::mlir_op`; both the reshaper and the mlir op are
/// constrained with `match::used_once()`.
struct find_mlir_reshape_ops
{
    auto matcher() const
    {
        // Work on a copy of the reshaper name set so "slice" can be dropped locally.
        auto supported = reshaper_names();
        // slice is not supported
        supported.erase("slice");

        auto single_use_mlir = match::name("gpu::mlir_op")(match::used_once());
        return match::name(supported)(match::arg(0)(single_use_mlir), match::used_once());
    }

    /// Builds a new submodule containing the original MLIR submodule followed
    /// by the matched reshaper op, then replaces the reshaper instruction with
    /// the mlir op (same operator, same inputs) bound to that new submodule.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reshape_ins = r.result;
        auto producer    = reshape_ins->inputs().front();

        // Submodule currently attached to the mlir op.
        auto* old_mod = producer->module_inputs().front();

        // New submodule is named "<old module name>:<reshaper op name>".
        const auto fused_name = old_mod->name() + ":" + reshape_ins->name();
        module_ref fused      = mpm.create_module(fused_name);
        fused->set_bypass();

        // Copy the old submodule into the new one, then append the reshaper
        // on top of what `fuse` returns and make that the module's result.
        auto fused_outputs = fused->fuse(*old_mod, producer->inputs());
        auto reshaped      = fused->add_instruction(reshape_ins->get_operator(), fused_outputs);
        fused->add_return({reshaped});

        auto& parent = mpm.get_module();
        parent.replace_instruction(
            reshape_ins, producer->get_operator(), producer->inputs(), {fused});
    }
};

/// Rewrites a 4-D convolution followed by a 5-D reshape so that the reshape
/// result goes through transpose -> contiguous -> transpose.
///
/// The two transpose permutations ({0, 1, 3, 4, 2} and {0, 1, 4, 2, 3}) are
/// inverses of each other, so the replacement has the same logical dimension
/// order as the original reshape.
/// NOTE(review): presumably the intent is to change the in-memory layout of
/// the reshaped tensor - confirm with the author.
struct find_convolution_reshape
{
    auto matcher() const
    {
        auto conv_arg = match::arg(0)(match::name("convolution").bind("convolution"));
        return match::name("reshape")(conv_arg);
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reshape_ins = r.result;
        auto conv_ins    = r.instructions["convolution"];
        auto rdims       = reshape_ins->get_shape().lens();
        auto cdims       = conv_ins->get_shape().lens();

        // Only a 4-D convolution reshaped into 5-D is handled.
        if(rdims.size() != 5 or cdims.size() != 4)
            return;

        // Bail out unless the last entry of the convolution's permutation is 1.
        auto conv_perm = find_permutation(conv_ins->get_shape());
        if(conv_perm.back() != 1)
            return;

        // Leading dimension must be preserved, and the reshape must only split
        // dimension 1: conv dims [2, end) have to equal reshape dims [3, end).
        if(rdims[0] != cdims[0] or
           not std::equal(cdims.begin() + 2, cdims.end(), rdims.begin() + 3, rdims.end()))
            return;

        // Size thresholds on the two dimensions produced by the split.
        // NOTE(review): rationale for 32 and 4 is not visible here - confirm.
        if(rdims[2] > 32 or rdims[1] < 4)
            return;

        auto& mod = mpm.get_module();

        // Re-emit the reshape before the matched one so the replacement chain
        // does not take the instruction being replaced as its own input.
        auto reshape_copy =
            mod.insert_instruction(reshape_ins, reshape_ins->get_operator(), reshape_ins->inputs());
        // A `layout` op with permutation {0, 1, 3, 4, 2} was tried here as an
        // alternative to the transpose/contiguous/transpose sequence below.
        auto permuted = mod.insert_instruction(
            reshape_ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape_copy);
        auto packed   = mod.insert_instruction(reshape_ins, make_op("contiguous"), permuted);
        auto restored = mod.insert_instruction(
            reshape_ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), packed);
        mod.replace_instruction(reshape_ins, restored);
    }
};

} // namespace

#endif // MIGRAPHX_MLIR
Expand All @@ -1061,6 +1124,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
return std::max(m1, m2);
};

match::find_matches(mpm, find_convolution_reshape{});
// Attention offloads; default disabled
if(mlir_attention_enabled(ctx) or enable_extra)
{
Expand Down Expand Up @@ -1092,6 +1156,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const

match::find_matches(mpm, find_pointwise_mlir{});
match::find_matches(mpm, find_unpack_int4_mlir_op{});

for(int i=0;i<4;i++)
match::find_matches(mpm, find_mlir_reshape_ops{});

#else
(void)mpm;
Expand Down
Loading