// aten/src/ATen/native/nested/NestedTensorTransformerUtils.h
#pragma once

#include <ATen/ATen.h>

namespace at::native::preprocessing {

/**
 * Takes nested query, key, and value tensors and preprocesses them
 * so that they can be run with either the flash-attention or
 * efficient-attention kernels.
 * @return A tuple containing all the data needed to run the fused
 * kernels.
 */
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor>
sdpa_nested_preprocessing(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value);
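
/*
 * Usage sketch (a hedged example, not taken from the source): the
 * element names below are assumptions inferred from the parameter
 * names of sdpa_nested_preprocessing_backward; only the element
 * types come from the declaration above.
 *
 *   auto [query_buffer, key_buffer, value_buffer,
 *         cumulative_sequence_length_q, cumulative_sequence_length_kv,
 *         max_seqlen_batch_q, max_seqlen_batch_kv, output_shape] =
 *       sdpa_nested_preprocessing(query, key, value);
 */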

/**
 * Takes nested query, key, value, grad_out, and out tensors and
 * preprocesses them so that the backward pass of either the
 * flash-attention or efficient-attention kernels can be run.
 * A separate function is used so that the preprocessing of
 * cumulative_sequence_length_q and cumulative_sequence_length_kv,
 * already done for the forward pass, does not have to be repeated.
 * @return A tuple containing all the data needed to run the fused
 * kernels.
 */
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
sdpa_nested_preprocessing_backward(
    const Tensor& grad_out_,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_kv,
    int64_t max_seqlen_batch_q,
    int64_t max_seqlen_batch_kv);
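
/*
 * Usage sketch (a hedged example, not taken from the source; the
 * binding names are assumptions): the cumulative sequence lengths
 * and max sequence lengths produced by sdpa_nested_preprocessing are
 * passed back in unchanged, which is how the shared preprocessing is
 * reused instead of recomputed for the backward kernels.
 *
 *   auto [grad_out_buffer, query_buffer, key_buffer,
 *         value_buffer, out_buffer] =
 *       sdpa_nested_preprocessing_backward(
 *           grad_out, query, key, value, out,
 *           cumulative_sequence_length_q, cumulative_sequence_length_kv,
 *           max_seqlen_batch_q, max_seqlen_batch_kv);
 */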

} // namespace at::native::preprocessing