/**
 * Transformer-specific NestedTensor utility functions.
 *
 * Not co-located with NestedTensor core code yet because they only
 * support specific cases needed in transformers.
 */
#pragma once

#include <vector>

#include <c10/macros/Macros.h>
#include <optional>

// Forward declarations keep this header light: only c10::Scalar,
// at::Tensor, and NestedTensorImpl are needed by-reference below.
namespace c10 {
class Scalar;
} // namespace c10

namespace at {
class Tensor;
namespace native {
struct NestedTensorImpl;

// Requires that self is a contiguous NestedTensor, other is not a
// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self
// must have a consistent last dimension across its included Tensors
// and that dimension must match other.size(0).
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other);

// Requires that mat1 is a contiguous NestedTensor, self & mat2 are
// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1
// has a consistent last dimension across its included Tensors that
// matches mat2.size(0).
// NOTE(review): `use_gelu` presumably fuses a GELU activation into
// the addmm epilogue when set — confirm against the definition.
Tensor NestedTensor_times_Tensor_plus_Tensor_addmm(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const c10::Scalar& beta,
    const c10::Scalar& alpha,
    std::optional<bool> use_gelu = std::nullopt);

// Adds `other` into `self`. Per the name the addition happens in
// place on `self` — TODO confirm whether both arguments must be
// NestedTensors with matching nested sizes; not visible from here.
Tensor NestedTensor_add_NestedTensor_in_place(
    const Tensor& self,
    const Tensor& other);

// Builds a batch-offsets Tensor from a NestedTensor `sizes` tensor.
// `extra_elements` presumably reserves additional trailing slots in
// the result — the exact layout is defined by the implementation.
TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor(
    const Tensor& sizes,
    int64_t extra_elements);

// CPU path converting a padded dense tensor back into a NestedTensor
// whose per-sample shape metadata is taken from `nt`.
Tensor NestedTensor_from_padded_tensor_cpu(
    const Tensor& padded,
    const NestedTensorImpl& nt);

// Converts NestedTensor `nt` into a (padding) mask. `mask_dim` and
// `mask_dim_length` optionally constrain the mask's dimensionality
// and size — supported combinations are checked by the definition.
Tensor NestedTensor_to_mask(const Tensor& nt, std::optional<int64_t> mask_dim, std::optional<int64_t> mask_dim_length);

// Launcher for the remove-padding kernel. Defined out of line —
// presumably in a CUDA translation unit given the raw-pointer,
// explicitly-sized interface; TODO confirm. Copies the unpadded
// regions described by `offsets`/`input_sizes`/`output_sizes` from
// `input` into `output` for `batch_size` samples.
template <typename T>
void remove_padding_kernelLauncher(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);

// Same interface as remove_padding_kernelLauncher; per the name it
// additionally applies a 0213 dimension permutation while removing
// padding — confirm the exact transform against the kernel.
template <typename T>
void remove_padding_transform0213_kernelLauncher(
    const T* input,
    T* output,
    const int* offsets,
    const int* input_sizes,
    const int* output_sizes,
    int output_dim,
    const int batch_size);

// Launcher for the add-padding kernel: expands the per-sample data in
// `input` into the padded layout described by `output_sizes`, filling
// the padded region with `padding_value`. `output_batch_size` allows
// the padded batch to differ from `batch_size` — see the definition
// for how the extra batch entries are treated.
template <typename T>
void add_padding_kernelLauncher(
    T* input,
    T* output,
    T padding_value,
    const int* offsets,
    const int* input_sizes,
    int input_dim,
    const std::vector<int64_t>& output_sizes,
    const int batch_size,
    const int output_batch_size);

// Helper routing query/key/value into the flash-attention kernel.
// NOTE(review): the dropout_p / need_attn_weights / is_causal flags
// appear to follow the scaled-dot-product-attention contract —
// confirm with the definition.
TORCH_API Tensor flash_attention_helper(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool need_attn_weights,
    bool is_causal);

// Helper for the memory-efficient attention path operating on
// unpacked (non-nested) query/key/value. Returns a tuple of two
// Tensors — presumably the attention output and the attention
// weights; TODO confirm against the definition.
TORCH_API std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool need_attn_weights,
    bool is_causal);
} // namespace native
} // namespace at