#pragma once
#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>
#include <ATen/native/DispatchStub.h>
#include <optional>

namespace at {
namespace native {

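// Dispatch stub used by scaled_dot_product_attention to select a fused SDPA
// backend for the given query/key/value, optional mask, dropout, causal flag,
// scale, and grouped-query-attention setting. The returned int64_t encodes
// the chosen backend (an sdp::SDPBackend value).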
using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
        const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);

DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);

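// Batched matrix multiply of `a` with `b` transposed on its last two dims
// (a @ b^T), e.g. for computing the Q @ K^T attention scores.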
TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b);
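// Softmax over the attention scores with an optional mask applied;
// `mask_type` identifies the kind of mask (e.g. key-padding vs. full
// attention mask), and `query` supplies shape context for expanding it.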
TORCH_API Tensor masked_softmax(
    Tensor& attn_scores,
    std::optional<Tensor> attn_mask,
    const Tensor& query,
    std::optional<int64_t> mask_type = {});

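// Kernel that unpacks a fused QKV projection: splits `_qkv` (plus `_qkv_bias`)
// into separate per-head Q, K, V buffers in `_q_k_v`, rescaling Q by
// 1/sqrt(dim_per_head). B, T, and D are batch size, sequence length, and
// embedding dim; dim_per_head = D / num_head.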
using transform_bias_rescale_qkv_fn = void(*)(
    at::ScalarType type,
    void* _q_k_v,
    const void* _qkv,
    const void* _qkv_bias,
    int64_t B,
    int64_t T,
    int64_t D,
    int64_t num_head);

DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub);

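// Permutes `a` from (B, num_head, T, dim_per_head) back to (B, T, D) via the
// 0213 transpose, then computes a @ b^T + c; `query` is consulted for the
// result's layout. Used for the attention output projection.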
TORCH_API Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& query);

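// Batched matrix multiply a @ b (no transpose), written into `out`,
// e.g. for applying attention weights to V.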
TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b);

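// Debug check that tensor `t` has the expected `shape`; `line` is the
// caller's source line, reported in the failure message.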
TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape);

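// Computes the packed QKV projection of query/key/value with `qkv_weight`,
// handling the self-attention case where query, key, and value alias.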
TORCH_API Tensor qkv_projection(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const Tensor& qkv_weight);

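// Forward flash-attention kernel: writes the attention output and the
// per-row logsumexp (saved for the backward pass) for the given
// query/key/value, dropout, causal flag, optional mask, and scale.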
using flash_attention_fn = void (*)(
    const Tensor& output, const Tensor& logsumexp,
    const Tensor& query, const Tensor& key, const Tensor& value,
    double dropout_p, bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale);

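// Backward flash-attention kernel: fills grad_q/grad_k/grad_v from the
// upstream grad_out, using the forward output and the saved logsumexp.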
using flash_attention_backward_fn = void (*)(
    const Tensor& grad_q, const Tensor& grad_k,
    const Tensor& grad_v, const Tensor& grad_out,
    const Tensor& query, const Tensor& key,
    const Tensor& value, const Tensor& out, const Tensor& logsumexp,
    double dropout_p, bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale);

DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel);
DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel);

} // namespace native
} // namespace at