
Commit cb699cd

Authored by ShawnXuan, crazy-JiangDongHua and oneflow-ci-bot

support fuse layer norm grad for npu (#10614)

Co-authored-by: JiangDongHua <759421566@qq.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
1 parent 5c62322 commit cb699cd

File tree

8 files changed: +245 additions, -16 deletions


oneflow/core/autograd/gradient_funcs/layer_norm.cpp

Lines changed: 33 additions & 16 deletions
@@ -18,6 +18,9 @@ limitations under the License.
 #include "oneflow/core/functional/functional.h"
 
 namespace oneflow {
+
+DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false);
+
 namespace one {
 
 struct LayerNormCaptureState : public AutoGradCaptureState {
@@ -107,22 +110,36 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
   std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);
 
-  if (ctx->has_affine) {
-    // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
-    // Int64 begin_params_axis)
-    const auto& results =
-        JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
-    in_grads->at(1) = results->at(0);  // For gamma.
-    in_grads->at(2) = results->at(1);  // For beta.
-  }
-  if (ctx->x_requires_grad) {
-    if (ctx->scale) {
-      std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
-      in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
-                                                             begin_norm_axis, ctx->epsilon));
-    } else {
-      in_grads->at(0) =
-          JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
+  if (EnvBool<ONEFLOW_USE_FUSE_LAYER_NORM_GRAD>()) {
+    // just for npu
+    CHECK(ctx->has_affine) << "LayerNorm::Apply must has_affine for NPU GPT2 test";
+    if (ctx->x_requires_grad) {
+      if (ctx->scale) {
+        std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
+        *in_grads = *JUST(functional::FuseLayerNormGrad(
+            dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
+      } else {
+        UNIMPLEMENTED();
+      }
+    }
+  } else {
+    if (ctx->has_affine) {
+      // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
+      // Int64 begin_params_axis)
+      const auto& results =
+          JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
+      in_grads->at(1) = results->at(0);  // For gamma.
+      in_grads->at(2) = results->at(1);  // For beta.
+    }
+    if (ctx->x_requires_grad) {
+      if (ctx->scale) {
+        std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
+        in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
+                                                               begin_norm_axis, ctx->epsilon));
+      } else {
+        in_grads->at(0) = JUST(
+            functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
+      }
     }
   }
   return Maybe<void>::Ok();
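The fused backward path above is opt-in: LayerNorm::Apply only takes it when the ONEFLOW_USE_FUSE_LAYER_NORM_GRAD flag (default false) evaluates to true. A minimal sketch of enabling it, assuming the flag is read from the process environment when first queried; the setenv call is illustrative and not part of this commit:

// Illustrative only: turn on the fused NPU layer-norm backward declared by
// DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false) in this diff.
// Assumes the flag is read from the process environment on first use.
#include <cstdlib>

int main() {
  setenv("ONEFLOW_USE_FUSE_LAYER_NORM_GRAD", "1", /*overwrite=*/1);  // POSIX
  // ... build the model and run backward; LayerNorm::Apply now dispatches
  // functional::FuseLayerNormGrad instead of the separate param/affine grads.
  return 0;
}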

oneflow/core/functional/functional_api.yaml

Lines changed: 4 additions & 0 deletions
@@ -1558,6 +1558,10 @@
   signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
   bind_python: False
 
+- name: "fuse_layer_norm_grad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
+  bind_python: False
+
 - name: "layer_norm_param_grad"
   signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad"
   bind_python: False
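Because bind_python is False, the new entry is only reachable from C++ through the generated functional::FuseLayerNormGrad. A hypothetical call site mirroring the signature above; the axis and epsilon values are examples, not taken from this commit:

// Hypothetical call site; axis/epsilon values are illustrative only.
const auto& results = JUST(functional::FuseLayerNormGrad(
    dy, x, mean, inv_variance, gamma,
    /*begin_norm_axis=*/2, /*begin_params_axis=*/2, /*epsilon=*/1e-5));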

oneflow/core/functional/impl/nn_grad_functor.cpp

Lines changed: 31 additions & 0 deletions
@@ -983,6 +983,36 @@ class LayerNormAffineGradFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class FuseLayerNormGradFunctor {
+ public:
+  FuseLayerNormGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
+                         .Input("dy")
+                         .Input("x")
+                         .Input("mean")
+                         .Input("inv_variance")
+                         .Input("gamma")
+                         .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& mean,
+                                const std::shared_ptr<one::Tensor>& inv_variance,
+                                const std::shared_ptr<one::Tensor>& gamma,
+                                const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
+                                const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class LayerNormParamGradFunctor {
  public:
   LayerNormParamGradFunctor() {
@@ -1707,6 +1737,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LayerNormGradFunctor>("LayerNormGrad");
   m.add_functor<impl::LayerNormAffineGradFunctor>("LayerNormAffineGrad");
   m.add_functor<impl::LayerNormParamGradFunctor>("LayerNormParamGrad");
+  m.add_functor<impl::FuseLayerNormGradFunctor>("FuseLayerNormGrad");
   m.add_functor<impl::GroupNormGradFunctor>("GroupNormGrad");
   m.add_functor<impl::GroupNormParamGradFunctor>("GroupNormParamGrad");
   m.add_functor<impl::BroadcastMatmulGradBFunctor>("BroadcastMatmulGradB");
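The functor declares its outputs as dx, gamma_diff and beta_diff, so the returned TensorTuple is presumably ordered the same way; this is what lets the autograd path above assign the whole tuple to in_grads in one step. A sketch of the presumed correspondence:

// Presumed output order of functional::FuseLayerNormGrad, matching the
// Output(...) declarations above and the in_grads slots in LayerNorm::Apply:
//   results->at(0) -> in_grads->at(0)  // dx         (gradient of x)
//   results->at(1) -> in_grads->at(1)  // gamma_diff (gradient of gamma)
//   results->at(2) -> in_grads->at(2)  // beta_diff  (gradient of beta)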

oneflow/core/job_rewriter/auto_mixed_precision.cpp

Lines changed: 2 additions & 0 deletions
@@ -359,6 +359,8 @@ REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0)
 REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0)
 REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0)
 REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0)
+REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "mean", 0)
+REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "inv_variance", 0)
 
 }  // namespace

oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
                                   "layer_norm",
                                   "layer_norm_param_grad",
                                   "layer_norm_grad",
+                                  "fuse_layer_norm_grad",
                                   "skip_layer_norm",
                                   "rms_norm",
                                   "rms_norm_grad",

oneflow/ir/include/OneFlow/OneFlowUserOps.td

Lines changed: 29 additions & 0 deletions
@@ -7071,6 +7071,35 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
   let has_data_type_infer_fn = 1;
 }
 
+def OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<"fuse_layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$dy,
+    OneFlow_Tensor:$x,
+    OneFlow_Tensor:$mean,
+    OneFlow_Tensor:$inv_variance,
+    Optional<OneFlow_Tensor>:$gamma,
+    Optional<OneFlow_Tensor>:$_add_to_output
+  );
+  let output = (outs
+    OneFlow_Tensor:$dx,
+    OneFlow_Tensor:$gamma_diff,
+    OneFlow_Tensor:$beta_diff
+  );
+  let attrs = (ins
+    DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
+    DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
+    DefaultValuedAttr<F64Attr, "0.">:$epsilon
+  );
+  let trait_attrs = (ins
+    DenseI32ArrayAttr:$operand_segment_sizes,
+    DenseI32ArrayAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
 def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
   let input = (ins
     OneFlow_Tensor:$dy,

oneflow/user/kernels/layer_norm_cpu_kernel.cpp

Lines changed: 20 additions & 0 deletions
@@ -57,6 +57,26 @@ class LayerNormGradCpuKernel final : public user_op::OpKernel {
 REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(float)
 REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(double)
 
+template<typename T>
+class FuseLayerNormGradCpuKernel final : public user_op::OpKernel {
+ public:
+  FuseLayerNormGradCpuKernel() = default;
+  ~FuseLayerNormGradCpuKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };
+};
+
+#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype)               \
+  REGISTER_USER_KERNEL("fuse_layer_norm_grad")                        \
+      .SetCreateFn<LayerNormGradCpuKernel<dtype>>()                   \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \
+                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value));
+
+REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)
+REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)
+
 template<typename T>
 class LayerNormParamGradCpuKernel final : public user_op::OpKernel {
  public:

oneflow/user/ops/layer_norm_op.cpp

Lines changed: 125 additions & 0 deletions
@@ -268,4 +268,129 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
   return Maybe<void>::Ok();
 }
 
+/* static */ Maybe<void> FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
+  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
+  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
+  CHECK_EQ_OR_RETURN(dy.shape(), x.shape()) << "dy and x shapes should be equal.";
+  const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  CHECK_GT_OR_RETURN(begin_norm_axis, 0) << "begin_norm_axis must be greater than 0.";
+  const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
+  CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape) << "mean shape must match bn_param_shape.";
+  CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape)
+      << "inv_variance shape must match bn_param_shape.";
+  dx->set_shape(dy.shape());
+  dx->set_is_dynamic(dy.is_dynamic());
+  if (ctx->has_input("_add_to_output", 0)) {
+    const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
+    CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape())
+        << "add_to_output shape must match dx shape.";
+  }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (const auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (const auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  CHECK_GE_OR_RETURN(begin_params_axis, 1)
+      << "begin_params_axis must be greater than or equal to 1.";
+  CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes())
+      << "begin_params_axis must be less than the number of axes in dy shape.";
+  DimVector param_shape_dim_vec;
+  param_shape_dim_vec.insert(param_shape_dim_vec.end(),
+                             dy.shape().dim_vec().cbegin() + begin_params_axis,
+                             dy.shape().dim_vec().cend());
+  const Shape param_shape(param_shape_dim_vec);
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_shape(param_shape);
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_shape(param_shape);
+  }
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {
+  std::vector<user_op::OpArg> broadcast_args;
+  if (ctx->user_op_conf().has_input("gamma", 0)) { broadcast_args.emplace_back("gamma", 0); }
+  int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  CHECK_EQ(begin_norm_axis, begin_params_axis)
+      << "begin_norm_axis and begin_params_axis must be equal, but got " << begin_norm_axis
+      << " and " << begin_params_axis;
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    ctx->NewBuilder()
+        .Split(ctx->inputs(), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .PartialSum(user_op::OpArg("gamma_diff", 0))
+        .PartialSum(user_op::OpArg("beta_diff", 0))
+        .Broadcast(broadcast_args)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())
+      << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got "
+      << DataType_Name(dy.data_type());
+  const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
+  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
+  DataType bn_param_data_type = InferBnParamDataType(x.data_type());
+  CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)
+      << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
+      << DataType_Name(mean.data_type());
+  CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)
+      << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
+      << DataType_Name(inv_variance.data_type());
+  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
+  dx->set_data_type(dy.data_type());
+  if (ctx->has_input("_add_to_output", 0)) {
+    const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
+    CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())
+        << "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
+        << DataType_Name(add_to_output.data_type());
+  }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_data_type(dy.data_type());
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_data_type(dy.data_type());
+  }
+  return Maybe<void>::Ok();
+}
+
 }  // namespace oneflow
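A worked reading of the inference functions above, with illustrative shapes that are not taken from the commit:

// Example shapes for a GPT-style input, with begin_norm_axis == begin_params_axis == 2:
//   x, dy                 : (8, 1024, 768)
//   mean, inv_variance    : (8, 1024)       // InferBnParamShape keeps dims [0, begin_norm_axis)
//   dx                    : (8, 1024, 768)  // same shape and data type as dy
//   gamma_diff, beta_diff : (768)           // dims [begin_params_axis, NumAxes)
// GetSbp then emits one signature per axis i < begin_norm_axis: the inputs and dx
// are split on i, gamma is broadcast, and gamma_diff/beta_diff come out partial-sum.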
