// aten/src/ATen/native/transformers/sdp_utils_cpp.cpp
#include <ATen/native/transformers/sdp_utils_cpp.h>
namespace sdp {
namespace {

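// On CPU only two backends are considered: the fused flash-attention kernel
// and the math (reference) fallback, tried in that order.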
std::array<SDPBackend, num_backends> priority_order_cpp(sdp_params const& params) {
  constexpr std::array<SDPBackend, num_backends> default_order{
      SDPBackend::flash_attention,
      SDPBackend::math};

  return default_order;
}

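// The CPU flash kernel requires query, key, and value to share the same last
// (head) dimension; this gate reports a mismatch when debug is enabled.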
bool check_head_dim_size_cpp(sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  if (!(query_size_last == key_size_last &&
        query_size_last == value_size_last)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention requires q,k,v to have the same last dimension.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

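// Runs every gate check in sequence; the flash kernel is only used when all
// constraints pass and the tensors use a supported floating-point dtype
// (float, double, bfloat16, half).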
bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
  constexpr auto cpp_supported_flash_dtypes =
      array_of<at::ScalarType>(at::kFloat, at::kDouble, at::kBFloat16, at::kHalf);

  // Define gate functions that determine if a flash kernel can be run
  constexpr auto constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_flash,
      check_nested_tensor,
      check_for_dropout,
      check_tensor_shapes,
      check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>,
      check_attn_mask_shape,
      check_head_dim_size_cpp,
      check_nonzero_sequence_lengths_dense,
      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>);
  for (auto& constraint : constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  return check_tensor_dtype(params, cpp_supported_flash_dtypes, debug);
}
} // namespace

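// Selects the SDP backend for the CPU path: walks the priority order and
// returns the first backend whose checks pass, otherwise reports why every
// kernel was rejected.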
SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params) {
  // This function defines the priority order of the different sdp backends
  // 1. Flash Attention
  // 2. Math fallback
  auto& ctx = at::globalContext();
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP()) {
    return SDPBackend::error;
  }
  // Get ideal kernel ordering
  const auto ordering = priority_order_cpp(kernel_params);

  // TORCH_CHECK fires when its condition is false, so print_debug is negated
  // below: the collected failure reasons are only printed once debug is
  // enabled.
  bool print_debug = false;
  for (auto& backend : ordering) {
    switch (backend) {
      case SDPBackend::flash_attention:
        if (use_flash_attention_cpp(kernel_params, print_debug)) {
          return SDPBackend::flash_attention;
        }
        break;
      case SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. use_flash_attention_cpp did not satisfy the constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected.

  print_debug = true;
  TORCH_WARN("Flash attention kernel not used because:");
  use_flash_attention_cpp(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel.  Aborting execution.");
  return SDPBackend::error;
}
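
// Usage sketch (illustrative only, not part of the original file): a caller
// holding an sdp_params `params` for the CPU path would dispatch on the
// selected backend roughly as follows:
//
//   switch (sdp::select_sdp_backend_cpp(params)) {
//     case sdp::SDPBackend::flash_attention:
//       // run the fused CPU flash-attention kernel
//       break;
//     case sdp::SDPBackend::math:
//       // fall back to the unfused reference implementation
//       break;
//     default:
//       // SDPBackend::error: both backends were disabled by the user
//       break;
//   }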
} // namespace sdp