xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /**
2  * Transformer-specific NestedTensor utility functions.
3  *
4  * Not co-located with NestedTensor core code yet because they only
5  * support specific cases needed in transformers.
6  */
7 #pragma once
8 
9 #include <vector>
10 
11 #include <c10/macros/Macros.h>
12 #include <optional>
13 
14 namespace c10 {
15 class Scalar;
16 } // namespace c10
17 
18 namespace at {
19 class Tensor;
20 namespace native {
21 struct NestedTensorImpl;
22 
// Matrix-multiplies a NestedTensor by a regular Tensor.
//
// Requires that self is a contiguous NestedTensor, other is not a
// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self
// must have a consistent last dimension across its included Tensors
// and that dimension must match other.size(0).
// Returns a NestedTensor; presumably each constituent of self is
// multiplied by other — confirm against the definition.
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other);
28 
// Fused addmm for a NestedTensor operand: computes something of the
// form beta * self + alpha * (mat1 @ mat2) — confirm exact semantics
// against the definition.
//
// Requires that mat1 is a contiguous NestedTensor, self & mat2 are
// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1
// has a consistent last dimension across its included Tensors that
// matches mat2.size(0).
//
// use_gelu: when set, presumably fuses a GELU activation into the
// result ("true" likely selecting the fused path) — TODO confirm how
// the optional's engaged/disengaged states differ in the definition.
Tensor NestedTensor_times_Tensor_plus_Tensor_addmm(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const c10::Scalar& beta,
    const c10::Scalar& alpha,
    std::optional<bool> use_gelu = std::nullopt);
40 
// NOTE(review): name suggests this accumulates `other` into `self`
// in place and returns the (modified) result; both operands are
// presumably NestedTensors with matching nested structure — confirm
// against the definition, which is not visible here.
Tensor NestedTensor_add_NestedTensor_in_place(
    const Tensor& self,
    const Tensor& other);
44 
// Builds a batch-offsets tensor from a NestedTensor's per-constituent
// sizes tensor. `extra_elements` presumably reserves additional
// trailing slots in the output beyond one offset per batch entry —
// TODO confirm layout (e.g. cumulative element offsets) against the
// definition. Exported (TORCH_API) for use outside this library.
TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor(
    const Tensor& sizes,
    int64_t extra_elements);
48 
// CPU path for converting a dense padded tensor back into a
// NestedTensor. `nt` supplies the target nested structure (the
// per-constituent sizes) used to strip the padding — confirm exact
// contract (e.g. which dims of `padded` are batch/padding) against
// the definition.
Tensor NestedTensor_from_padded_tensor_cpu(
    const Tensor& padded,
    const NestedTensorImpl& nt);
52 
// Produces a padding mask from a NestedTensor. `mask_dim` presumably
// selects the dimensionality of the returned mask, and
// `mask_dim_length` a fixed length to pad/truncate the mask dimension
// to; std::nullopt likely means "derive from nt" — TODO confirm the
// mask's dtype and the meaning of true vs false entries against the
// definition.
Tensor NestedTensor_to_mask(const Tensor& nt, std::optional<int64_t> mask_dim, std::optional<int64_t> mask_dim_length);
54 
// Launches a kernel that strips padding from a dense batch `input`
// into packed `output` (the NestedTensor buffer layout). The
// "kernelLauncher" suffix suggests a CUDA launch wrapper defined in a
// .cu file — confirm; no CUDA constructs are visible in this header.
//
// input        : padded source data, batch-major (presumed).
// output       : destination buffer for the unpadded data.
// offsets      : per-batch-entry offsets into output — presumed
//                element offsets; confirm against the kernel.
// input_sizes  : per-entry dims of the padded input.
// output_sizes : per-entry dims of the unpadded output.
// output_dim   : number of dims described by output_sizes (presumed).
// batch_size   : number of batch entries to process.
template <typename T>
void remove_padding_kernelLauncher(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);
64 
// Same contract as remove_padding_kernelLauncher, but the "0213" in
// the name suggests the kernel additionally permutes dims (0,2,1,3)
// while removing padding — a common attention-head reshuffle. TODO
// confirm the exact permutation against the kernel definition.
// Parameters mirror remove_padding_kernelLauncher.
template <typename T>
void remove_padding_transform0213_kernelLauncher(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);
74 
// Inverse of remove_padding_kernelLauncher: launches a kernel that
// scatters packed (nested) `input` into a dense padded `output`,
// filling unused slots with `padding_value`. Launcher presumably
// wraps a CUDA kernel defined elsewhere — confirm.
//
// input             : packed source data (non-const; TODO confirm
//                     whether the kernel actually mutates it or the
//                     missing const is incidental).
// output            : padded destination buffer.
// padding_value     : fill value for padded positions.
// offsets           : per-entry offsets into the packed input
//                     (presumed element offsets).
// input_sizes       : per-entry dims of the packed input.
// input_dim         : number of dims described by input_sizes
//                     (presumed).
// output_sizes      : dims of the dense padded output.
// batch_size        : number of real batch entries.
// output_batch_size : batch size of the padded output; may exceed
//                     batch_size (extra entries presumably all
//                     padding) — confirm.
template <typename T>
void add_padding_kernelLauncher(
    T* input,
    T* output,
    T padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    const std::vector<int64_t>& output_sizes,
    const int batch_size,
    const int output_batch_size);
86 
// Helper dispatching scaled-dot-product attention to the
// FlashAttention path (per the name — confirm against definition).
// Exported via TORCH_API.
//
// query/key/value   : attention inputs; expected layout not visible
//                     here — TODO confirm (likely nested or
//                     [batch, heads, seq, head_dim]).
// dropout_p         : attention dropout probability.
// need_attn_weights : whether attention weights are needed by the
//                     caller — note only a single Tensor is returned,
//                     so confirm how this flag affects the result.
// is_causal         : apply a causal (lower-triangular) mask,
//                     presumably.
TORCH_API Tensor flash_attention_helper(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool need_attn_weights,
    bool is_causal);
94 
// Helper dispatching attention to the memory-efficient kernel for
// the "nested, unpacked" q/k/v case (per the name — confirm against
// definition). Exported via TORCH_API. Parameters mirror
// flash_attention_helper.
//
// Returns a pair of Tensors — presumably (output, attention weights
// or logsumexp); TODO confirm the second element's meaning and its
// interaction with need_attn_weights.
TORCH_API std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool need_attn_weights,
    bool is_causal);
102 } // namespace native
103 } // namespace at
104