Commit ea585f6

flash_attention_v2_backward (#10495)
flash attention v2 backward operator --------- Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
1 parent 44ad994 commit ea585f6

9 files changed, +746 -26 lines changed
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
#if CUDA_VERSION >= 11070

namespace oneflow {

namespace one {

struct ScaledDotProductFlashAttentionCaptureState : public AutoGradCaptureState {
  bool query_requires_grad = true;
  bool key_requires_grad = true;
  bool value_requires_grad = true;
  size_t query_idx = 0;
  size_t key_idx = 0;
  size_t value_idx = 0;
  size_t out_idx = 0;
  size_t softmax_lse_idx = 0;
  size_t rng_state_idx = 0;
  float p_dropout = .0f;
  float softmax_scale = .0f;
  bool is_causal = false;
};

class ScaledDotProductFlashAttention
    : public OpExprGradFunction<ScaledDotProductFlashAttentionCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be None. ";
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ScaledDotProductFlashAttentionCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 3) << "Input size should be equal to 3. ";
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->p_dropout = JUST(composed_attrs.GetAttr<float>("p_dropout"));
    ctx->softmax_scale = JUST(composed_attrs.GetAttr<float>("softmax_scale"));
    ctx->is_causal = JUST(composed_attrs.GetAttr<bool>("is_causal"));
    ctx->query_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
    ctx->key_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad();
    ctx->value_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad();
    ctx->query_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
    ctx->key_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));
    ctx->value_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2)));
    ctx->out_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
    ctx->softmax_lse_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1)));
    ctx->rng_state_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 2)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScaledDotProductFlashAttentionCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 3) << "Out grads size should be equal to 3. ";
    std::shared_ptr<oneflow::one::TensorTuple> grads;
    in_grads->resize(3);
    grads = JUST(functional::ScaledDotProductFlashAttentionGrad(
        JUST(oneflow::VectorAt(out_grads, 0)),
        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->query_idx)),
        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->key_idx)),
        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->value_idx)),
        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->out_idx)),
        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->softmax_lse_idx)),
        JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->rng_state_idx)), ctx->p_dropout,
        ctx->is_causal, ctx->softmax_scale));

    if (ctx->query_requires_grad) {
      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0));
    }
    if (ctx->key_requires_grad) {
      JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1));
    }
    if (ctx->value_requires_grad) {
      JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2));
    }

    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("scaled_dot_product_flash_attention",
                               ScaledDotProductFlashAttention);

}  // namespace one

}  // namespace oneflow

#endif // CUDA_VERSION >= 11070
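A note on what Capture saves: besides query, key and value, the forward's out, softmax_lse and rng_state are stashed as well, because flash attention never materializes the full softmax matrix. The backward recomputes the attention probabilities from the raw scores and the saved per-row log-sum-exp, and rng_state lets the dropout pattern be regenerated. The toy sketch below is standalone C++ with made-up values, not OneFlow code or the actual CUDA kernel; it only demonstrates the recomputation identity P[i][j] = exp(scale * <q_i, k_j> - lse_i).

```cpp
// Standalone toy sketch (not OneFlow code): the backward can rebuild each
// attention probability from the raw scores and the saved row-wise
// log-sum-exp, so the forward only needs to keep softmax_lse, not the
// full (seqlen_q x seqlen_k) softmax matrix.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int seqlen = 3, head_size = 2;
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  const std::vector<std::vector<float>> q = {{1, 0}, {0, 1}, {1, 1}};
  const std::vector<std::vector<float>> k = {{1, 1}, {0, 2}, {2, 0}};

  for (int i = 0; i < seqlen; ++i) {
    std::vector<float> s(seqlen, 0.0f);  // scores s_ij = scale * <q_i, k_j>
    for (int j = 0; j < seqlen; ++j) {
      for (int d = 0; d < head_size; ++d) { s[j] += scale * q[i][d] * k[j][d]; }
    }
    float sum_exp = 0.0f;
    for (int j = 0; j < seqlen; ++j) { sum_exp += std::exp(s[j]); }
    const float lse = std::log(sum_exp);  // what softmax_lse stores per row
    for (int j = 0; j < seqlen; ++j) {
      // P[i][j] recovered from the score and lse alone equals the softmax value.
      std::printf("P[%d][%d] = %.4f\n", i, j, std::exp(s[j] - lse));
    }
  }
  return 0;
}
```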

oneflow/core/functional/functional_api.yaml

Lines changed: 4 additions & 0 deletions
@@ -2688,6 +2688,10 @@
  signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention"
  bind_python: True

- name: "scaled_dot_product_attention_grad"
  signature: "TensorTuple (Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor softmax_lse, Tensor rng_state, Float dropout_p=0.0, Bool is_causal=False, Float scale=0.0) => ScaledDotProductFlashAttentionGrad"
  bind_python: False

- name: "fused_multi_head_attention_inference"
  signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference"
  bind_python: True

oneflow/core/functional/impl/nn_functor.cpp

Lines changed: 131 additions & 0 deletions
@@ -5538,6 +5538,135 @@ class ScaledDotProductFlashAttentionFunctor {
#endif // CUDA_VERSION >= 11070
};

class ScaledDotProductFlashAttentionGradFunctor {
 public:
  ScaledDotProductFlashAttentionGradFunctor() {
#if CUDA_VERSION >= 11070
    op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention_grad")
                         .Input("grad_out")
                         .Input("query")
                         .Input("key")
                         .Input("value")
                         .Input("out")
                         .Input("softmax_lse")
                         .Input("rng_state")
                         .Output("grad_q")
                         .Output("grad_k")
                         .Output("grad_v")
                         .Build());
#endif
  }

  Maybe<TensorTuple> operator()(
      const std::shared_ptr<one::Tensor>& grad_out, const std::shared_ptr<one::Tensor>& query,
      const std::shared_ptr<one::Tensor>& key, const std::shared_ptr<one::Tensor>& value,
      const std::shared_ptr<one::Tensor>& out, const std::shared_ptr<one::Tensor>& softmax_lse,
      const std::shared_ptr<one::Tensor>& rng_state, const float& dropout_p, const bool& is_causal,
      const float& scale) const {
#if CUDA_VERSION >= 11070
    // grad_out    (batch x q_seq_len  x num_heads   x head_size)
    // query       (batch x q_seq_len  x num_heads   x head_size_padded)
    // key         (batch x kv_seq_len x num_heads_k x head_size_padded)
    // value       (batch x kv_seq_len x num_heads_k x head_size_padded)
    // out         (batch x q_seq_len  x num_heads   x head_size_padded)
    // softmax_lse (batch x num_heads  x q_seq_len)
    const auto head_size = grad_out->shape()->At(3);
    const auto head_size_padded = query->shape()->At(3);
    const auto batch_size = query->shape()->At(0);
    const auto seqlen_q = query->shape()->At(1);
    const auto seqlen_k = key->shape()->At(1);
    const auto num_heads = query->shape()->At(2);
    const auto num_heads_k = key->shape()->At(2);
    CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))
        << " key has different batch size from query.";
    CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))
        << " value has different batch size from query.";
    CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0))
        << " grad_out has different batch size from query.";
    CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0))
        << " out has different batch size from query.";
    CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0))
        << " softmax_lse has different batch size from query.";
    CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2))
        << " grad_out has different num_heads from query.";
    CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1))
        << " softmax_lse has different num_heads from query.";
    CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2))
        << " value has different num_heads from key.";
    CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1))
        << " grad_out has different seq_len from query.";
    CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2))
        << " softmax_lse has different seq_len from query.";
    CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3))
        << " key has different head dims from query.";
    CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3))
        << " value has different head dims from query.";
    CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3))
        << " out has different head dims from query.";

    bool padded = head_size % 8;

    auto grad_out_ = padded ? JUST(pad_last_dim<8>(grad_out)) : grad_out;

    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal",
                                                 "window_size_left", "window_size_right");
    attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1);

    auto output = std::make_shared<TensorTuple>(3);
    auto output_ = JUST(OpInterpUtil::Dispatch<TensorTuple>(
        *op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs));
    CHECK_EQ(output_->size(), 3);
    auto grad_q_ = (*output_)[0];
    auto grad_k_ = (*output_)[1];
    auto grad_v_ = (*output_)[2];

    std::shared_ptr<Tensor> grad_q_padded, grad_k_padded, grad_v_padded;

    bool expanded = num_heads != num_heads_k;

    grad_q_padded = grad_q_;
    if (expanded) {
      grad_k_padded = JUST(functional::ReduceSum(
          JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k,
                                             num_heads / num_heads_k, head_size_padded})),
          {3}, false, grad_k_->dtype()));
      grad_v_padded = JUST(functional::ReduceSum(
          JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k,
                                             num_heads / num_heads_k, head_size_padded})),
          {3}, false, grad_v_->dtype()));
    } else {
      grad_k_padded = grad_k_;
      grad_v_padded = grad_v_;
    }

    auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0},
                                                  {batch_size, seqlen_q, num_heads, head_size},
                                                  {1, 1, 1, 1}, false))
                         : grad_q_padded;
    auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0},
                                                  {batch_size, seqlen_k, num_heads_k, head_size},
                                                  {1, 1, 1, 1}, false))
                         : grad_k_padded;
    auto grad_v = padded ? JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0},
                                                  {batch_size, seqlen_k, num_heads_k, head_size},
                                                  {1, 1, 1, 1}, false))
                         : grad_v_padded;

    (*output)[0] = grad_q;
    (*output)[1] = grad_k;
    (*output)[2] = grad_v;
    return output;

#endif // CUDA_VERSION >= 11070

    UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070.";
  }

 private:
#if CUDA_VERSION >= 11070
  std::shared_ptr<OpExpr> op_;
#endif // CUDA_VERSION >= 11070
};
}  // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
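About the padded branch in the functor above: the flash-attention kernels work on a head dimension padded to a multiple of 8, so when head_size % 8 != 0 the functor pads grad_out up to head_size_padded before dispatching the op and slices the returned gradients back to head_size afterwards. Below is a minimal standalone sketch of that pad-then-slice round trip on a single row (plain C++ with toy values; the PadLastDim helper is illustrative, not the OneFlow pad_last_dim / functional::Slice used above).

```cpp
#include <cstdio>
#include <vector>

// Pad the last dimension up to the next multiple of `align` with zeros.
std::vector<float> PadLastDim(const std::vector<float>& v, int align) {
  const int padded_size = (static_cast<int>(v.size()) + align - 1) / align * align;
  std::vector<float> out(v);
  out.resize(padded_size, 0.0f);
  return out;
}

int main() {
  const int head_size = 5;                        // not a multiple of 8
  std::vector<float> grad_out = {1, 2, 3, 4, 5};  // one row of grad_out
  const bool padded = (head_size % 8) != 0;       // same test as the functor

  // The kernel sees the padded width; the extra zero columns contribute
  // nothing to the real gradients.
  std::vector<float> grad_padded = padded ? PadLastDim(grad_out, 8) : grad_out;
  std::printf("padded width = %zu\n", grad_padded.size());  // 8

  // Slice back to the original head_size, mirroring functional::Slice on
  // grad_q / grad_k / grad_v in the functor.
  std::vector<float> grad(grad_padded.begin(), grad_padded.begin() + head_size);
  std::printf("sliced width = %zu\n", grad.size());  // 5
  return 0;
}
```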
@@ -5676,6 +5805,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
  m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad");
  m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention");
  m.add_functor<impl::ScaledDotProductFlashAttentionGradFunctor>(
      "ScaledDotProductFlashAttentionGrad");
}

}  // namespace functional
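The expanded branch covers grouped-query / multi-query attention: when num_heads != num_heads_k, each key/value head is shared by num_heads / num_heads_k query heads, so ScaledDotProductFlashAttentionGradFunctor reshapes the per-query-head key/value gradients to expose that group axis and sums over it with ReduceSum. A minimal standalone illustration with hypothetical sizes (plain arrays, batch and sequence dimensions omitted, not the OneFlow functors):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes: 4 query heads sharing 2 key/value heads (group of 2).
  const int num_heads = 4, num_heads_k = 2, head_size = 3;
  const int group = num_heads / num_heads_k;

  // grad_k as produced per query head: [num_heads][head_size].
  const std::vector<std::vector<float>> grad_k_per_q_head = {
      {1, 1, 1}, {2, 2, 2}, {10, 10, 10}, {20, 20, 20}};

  // Split the head axis into (num_heads_k, group) and sum over the group,
  // mirroring the Reshape + ReduceSum({3}) in the functor above.
  std::vector<std::vector<float>> grad_k(num_heads_k, std::vector<float>(head_size, 0.0f));
  for (int h = 0; h < num_heads; ++h) {
    const int kv_head = h / group;  // query heads in one group share a kv head
    for (int d = 0; d < head_size; ++d) { grad_k[kv_head][d] += grad_k_per_q_head[h][d]; }
  }

  for (int h = 0; h < num_heads_k; ++h) {
    std::printf("grad_k[%d] = {%.0f, %.0f, %.0f}\n", h, grad_k[h][0], grad_k[h][1], grad_k[h][2]);
  }
  // Expected: grad_k[0] = {3, 3, 3}, grad_k[1] = {30, 30, 30}.
  return 0;
}
```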

oneflow/ir/include/OneFlow/OneFlowUserOps.td

Lines changed: 29 additions & 0 deletions
@@ -2903,6 +2903,35 @@ def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_produc
  let has_data_type_infer_fn = 1;
}

def OneFlow_ScaledDotProductFlashAttentionGradOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$grad_out,
    OneFlow_Tensor:$query,
    OneFlow_Tensor:$key,
    OneFlow_Tensor:$value,
    OneFlow_Tensor:$out,
    OneFlow_Tensor:$softmax_lse,
    OneFlow_Tensor:$rng_state,
    Optional<OneFlow_Tensor>:$alibi_slopes_
  );
  let output = (outs
    OneFlow_Tensor:$grad_q,
    OneFlow_Tensor:$grad_k,
    OneFlow_Tensor:$grad_v
  );
  let attrs = (ins
    DefaultValuedAttr<F32Attr, "0.">:$p_dropout,
    DefaultValuedAttr<F32Attr, "0.">:$softmax_scale,
    DefaultValuedAttr<BoolAttr, "false">:$is_causal,
    SI32Attr:$window_size_left,
    SI32Attr:$window_size_right
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$query,
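On the window_size_left / window_size_right attributes of the new op: the grad functor always passes {-1, -1}. Assuming the usual flash-attention v2 convention for local (sliding-window) attention, which is an assumption here since the CUDA kernel is not part of this diff, a value of -1 disables that side of the window, so {-1, -1} means full attention and causal masking is governed solely by is_causal. A small standalone sketch of that interpretation:

```cpp
// Hedged sketch of the assumed flash-attention v2 window semantics; the
// actual kernel behavior is not shown in this commit.
#include <cstdio>

bool IsAttended(int q_pos, int k_pos, int window_left, int window_right, bool is_causal) {
  if (is_causal && k_pos > q_pos) { return false; }
  if (window_left >= 0 && k_pos < q_pos - window_left) { return false; }
  if (window_right >= 0 && k_pos > q_pos + window_right) { return false; }
  return true;
}

int main() {
  // With {-1, -1}, every (query, key) pair inside the causal triangle is kept.
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) { std::printf("%d ", IsAttended(i, j, -1, -1, true) ? 1 : 0); }
    std::printf("\n");
  }
  return 0;
}
```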
