1 #include <ATen/ATen.h> 2 3 namespace at::native::preprocessing { 4 5 /** 6 * This function will take nested query, key, and value 7 * and will preprocess it in order to run with either 8 * the flash-attention or efficient-attention kernels. 9 * @return A tuple containing all the necessary data for running the fused 10 * kernels 11 */ 12 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor> 13 sdpa_nested_preprocessing( 14 const Tensor& query, 15 const Tensor& key, 16 const Tensor& value); 17 18 /** 19 * This function will take nested query, key, and value, grad_out, and out 20 * and will preprocess it in order to run with either 21 * the flash-attention or efficient-attention kernels backwards. 22 * We use both functions to avoid having to do the same preprocessing 23 * for cumulative_sequence_length_q and cumulative_sequence_length_kv 24 * @return A tuple containing all the necessary data for running the fused 25 * kernels 26 */ 27 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> 28 sdpa_nested_preprocessing_backward( 29 const at::Tensor& grad_out_, 30 const at::Tensor& query, 31 const at::Tensor& key, 32 const at::Tensor& value, 33 const at::Tensor& out, 34 const Tensor& cumulative_sequence_length_q, 35 const Tensor& cumulative_sequence_length_kv, 36 const int64_t max_seqlen_batch_q, 37 const int64_t max_seqlen_batch_kv); 38 39 } // namespace at::native::preprocessing 40