@@ -5538,6 +5538,135 @@ class ScaledDotProductFlashAttentionFunctor {
#endif // CUDA_VERSION >= 11070
};
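+// Backward functor for ScaledDotProductFlashAttentionFunctor: wraps the
+// "scaled_dot_product_flash_attention_grad" op and returns (grad_q, grad_k, grad_v).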
+class ScaledDotProductFlashAttentionGradFunctor {
+ public:
+  ScaledDotProductFlashAttentionGradFunctor() {
+#if CUDA_VERSION >= 11070
+    op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention_grad")
+                         .Input("grad_out")
+                         .Input("query")
+                         .Input("key")
+                         .Input("value")
+                         .Input("out")
+                         .Input("softmax_lse")
+                         .Input("rng_state")
+                         .Output("grad_q")
+                         .Output("grad_k")
+                         .Output("grad_v")
+                         .Build());
+#endif
+  }
+
+  Maybe<TensorTuple> operator()(
+      const std::shared_ptr<one::Tensor>& grad_out, const std::shared_ptr<one::Tensor>& query,
+      const std::shared_ptr<one::Tensor>& key, const std::shared_ptr<one::Tensor>& value,
+      const std::shared_ptr<one::Tensor>& out, const std::shared_ptr<one::Tensor>& softmax_lse,
+      const std::shared_ptr<one::Tensor>& rng_state, const float& dropout_p, const bool& is_causal,
+      const float& scale) const {
+#if CUDA_VERSION >= 11070
+    // grad_out    (batch x q_seq_len  x num_heads   x head_size)
+    // query       (batch x q_seq_len  x num_heads   x head_size_padded)
+    // key         (batch x kv_seq_len x num_heads_k x head_size_padded)
+    // value       (batch x kv_seq_len x num_heads_k x head_size_padded)
+    // out         (batch x q_seq_len  x num_heads   x head_size_padded)
+    // softmax_lse (batch x num_heads  x q_seq_len)
+    const auto head_size = grad_out->shape()->At(3);
+    const auto head_size_padded = query->shape()->At(3);
+    const auto batch_size = query->shape()->At(0);
+    const auto seqlen_q = query->shape()->At(1);
+    const auto seqlen_k = key->shape()->At(1);
+    const auto num_heads = query->shape()->At(2);
+    const auto num_heads_k = key->shape()->At(2);
+    CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))
+        << "key has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))
+        << "value has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0))
+        << "grad_out has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0))
+        << "out has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0))
+        << "softmax_lse has different batch size from query.";
+    CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2))
+        << "grad_out has different num_heads from query.";
+    CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1))
+        << "softmax_lse has different num_heads from query.";
+    CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2))
+        << "value has different num_heads from key.";
+    CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1))
+        << "grad_out has different seq_len from query.";
+    CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2))
+        << "softmax_lse has different seq_len from query.";
+    CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3))
+        << "key has different head dims from query.";
+    CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3))
+        << "value has different head dims from query.";
+    CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3))
+        << "out has different head dims from query.";
+
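+    // The flash-attention kernel expects the head dim to be a multiple of 8, so grad_out is
+    // padded to head_size_padded here and the returned gradients are sliced back below.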
+    bool padded = head_size % 8 != 0;
+
+    auto grad_out_ = padded ? JUST(pad_last_dim<8>(grad_out)) : grad_out;
+
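+    // Forward the attention attributes; window_size_left/right = -1 disables sliding-window
+    // (local) attention.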
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal",
+                                                 "window_size_left", "window_size_right");
+    attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1);
+
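+    // Dispatch the grad op; its outputs are ordered (grad_q, grad_k, grad_v), still with the
+    // padded head dim.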
+    auto output = std::make_shared<TensorTuple>(3);
+    auto output_ = JUST(OpInterpUtil::Dispatch<TensorTuple>(
+        *op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs));
+    CHECK_EQ(output_->size(), 3);
+    auto grad_q_ = (*output_)[0];
+    auto grad_k_ = (*output_)[1];
+    auto grad_v_ = (*output_)[2];
+
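+    // For multi-query / grouped-query attention (num_heads_k != num_heads) the kernel produces
+    // per-query-head key/value gradients, so they are summed back to num_heads_k heads.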
+    std::shared_ptr<Tensor> grad_q_padded, grad_k_padded, grad_v_padded;
+
+    bool expanded = num_heads != num_heads_k;
+
+    grad_q_padded = grad_q_;
+    if (expanded) {
+      grad_k_padded = JUST(functional::ReduceSum(
+          JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k,
+                                             num_heads / num_heads_k, head_size_padded})),
+          {3}, false, grad_k_->dtype()));
+      grad_v_padded = JUST(functional::ReduceSum(
+          JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k,
+                                             num_heads / num_heads_k, head_size_padded})),
+          {3}, false, grad_v_->dtype()));
+    } else {
+      grad_k_padded = grad_k_;
+      grad_v_padded = grad_v_;
+    }
+
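+    // Slice the padding off the last dim so the gradients match the original (unpadded) head_size.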
+    auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0},
+                                                  {batch_size, seqlen_q, num_heads, head_size},
+                                                  {1, 1, 1, 1}, false))
+                         : grad_q_padded;
+    auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0},
+                                                  {batch_size, seqlen_k, num_heads_k, head_size},
+                                                  {1, 1, 1, 1}, false))
+                         : grad_k_padded;
+    auto grad_v = padded ? JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0},
+                                                  {batch_size, seqlen_k, num_heads_k, head_size},
+                                                  {1, 1, 1, 1}, false))
+                         : grad_v_padded;
+
+    (*output)[0] = grad_q;
+    (*output)[1] = grad_k;
+    (*output)[2] = grad_v;
+    return output;
+
+#endif  // CUDA_VERSION >= 11070
+
+    UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070.";
+  }
+
+ private:
+#if CUDA_VERSION >= 11070
+  std::shared_ptr<OpExpr> op_;
+#endif  // CUDA_VERSION >= 11070
+};
} // namespace impl
@@ -5676,6 +5805,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
  m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad");
  m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention");
+  m.add_functor<impl::ScaledDotProductFlashAttentionGradFunctor>(
+      "ScaledDotProductFlashAttentionGrad");
}
} // namespace functional