#pragma once
#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include <optional>

namespace at {
namespace native {

// Selects a fused scaled-dot-product-attention backend for the given inputs;
// the returned int64_t encodes the chosen sdp::SDPBackend value.
using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
    const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
    std::optional<double> scale, bool enable_gqa);

DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);

// Batched matmul with the second operand transposed: a @ b^T.
TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b);

// Softmax over attn_scores with an optional mask; mask_type, when provided,
// hints at the mask's layout so a specialized kernel can be chosen.
TORCH_API Tensor masked_softmax(
    Tensor& attn_scores,
    std::optional<Tensor> attn_mask,
    const Tensor& query,
    std::optional<int64_t> mask_type = {});

// Splits a packed QKV projection ([B, T, 3*D] input plus a [3*D] bias) into
// separate per-head q/k/v buffers, pre-scaling q by 1/sqrt(D / num_head).
// B = batch size, T = sequence length, D = embedding dim (divisible by
// num_head).
using transform_bias_rescale_qkv_fn = void (*)(
    at::ScalarType type,
    void* _q_k_v,
    const void* _qkv,
    const void* _qkv_bias,
    int64_t B,
    int64_t T,
    int64_t D,
    int64_t num_head);

DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub);

TORCH_API Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& query);

// Batched matmul a @ b, written into `out`.
TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b);

// `line` is the caller's source line (typically __LINE__), surfaced in the
// assertion message.
TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape);

TORCH_API Tensor qkv_projection(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const Tensor& qkv_weight);

// Forward flash-attention kernel: writes the attention result into `output`
// and the per-row log-sum-exp of the attention scores into `logsumexp`.
using flash_attention_fn = void (*)(
    const Tensor& output, const Tensor& logsumexp,
    const Tensor& query, const Tensor& key, const Tensor& value,
    double dropout_p, bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale);

// Backward kernel: fills grad_q/grad_k/grad_v from grad_out using the saved
// forward output and logsumexp.
using flash_attention_backward_fn = void (*)(
    const Tensor& grad_q, const Tensor& grad_k,
    const Tensor& grad_v, const Tensor& grad_out,
    const Tensor& query, const Tensor& key,
    const Tensor& value, const Tensor& out, const Tensor& logsumexp,
    double dropout_p, bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale);

DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel);
DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel);

} // namespace native
} // namespace at
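
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this header's API): how a backend might
// provide a kernel for the flash_attention_kernel stub declared above,
// following the usual ATen DispatchStub pattern. `cpu_flash_attention` is a
// hypothetical name; the definition and registration would live in a kernel
// .cpp translation unit, not in this header.
//
//   // e.g. in FlashAttentionKernel.cpp
//   #include <ATen/native/transformers/attention.h>
//
//   namespace at::native {
//
//   static void cpu_flash_attention(
//       const Tensor& output, const Tensor& logsumexp,
//       const Tensor& query, const Tensor& key, const Tensor& value,
//       double dropout_p, bool is_causal,
//       std::optional<Tensor> attn_mask,
//       std::optional<double> scale) {
//     // ... compute attention into `output`, per-row LSE into `logsumexp` ...
//   }
//
//   REGISTER_DISPATCH(flash_attention_kernel, &cpu_flash_attention);
//
//   } // namespace at::native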
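
// A minimal caller-side sketch, assuming the standard DispatchStub call
// convention (device type first, then the signature's arguments). The wrapper
// function and the [B, H, T, Dh] layout it presumes for q/k/v are illustrative
// assumptions, not an API defined here:
//
//   void run_flash_attention(
//       const Tensor& output, const Tensor& logsumexp,
//       const Tensor& query, const Tensor& key, const Tensor& value) {
//     flash_attention_kernel(
//         query.device().type(),  // routes to the registered backend kernel
//         output, logsumexp,
//         query, key, value,
//         /*dropout_p=*/0.0, /*is_causal=*/true,
//         /*attn_mask=*/std::nullopt, /*scale=*/std::nullopt);
//   }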
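
// A sketch of the element-wise semantics assumed for
// transform_bias_rescale_qkv_stub, with Dh = D / num_head and the packed
// input laid out as [B, T, 3*D] (q, then k, then v along the last axis).
// Treat this as a reference description under those assumptions, not as the
// optimized kernel:
//
//   for each b < B, h < num_head, t < T, d < Dh:
//     q[b][h][t][d] = (qkv[b][t][0*D + h*Dh + d] + qkv_bias[0*D + h*Dh + d])
//                     / std::sqrt(static_cast<double>(Dh));
//     k[b][h][t][d] =  qkv[b][t][1*D + h*Dh + d] + qkv_bias[1*D + h*Dh + d];
//     v[b][h][t][d] =  qkv[b][t][2*D + h*Dh + d] + qkv_bias[2*D + h*Dh + d];
//
//   where q, k, v are the three [B, num_head, T, Dh] slices of the packed
//   _q_k_v output buffer. Only q is rescaled, so the later dot product
//   q @ k^T already carries the 1/sqrt(Dh) softmax scaling.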