xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/util/CallOnce.h>

#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>

#if USE_ROCM
#include <aotriton/flash.h>
#endif

/**
* Note [SDPA Runtime Dispatch]
* SDPA relies on a runtime dispatch mechanism to select the appropriate
* kernel. This file exposes that mechanism through `select_sdp_backend`.
* The basic structure of this function is to call `priority_order` to get a
* list of backends to try, and then iterate through them until one succeeds.
* Each backend defines a use_<backend> function that returns true if the
* backend can be run with the given SDP parameters. The use_<backend> function
* will iterate over a list of "filters" that check for specific properties of
* the SDP parameters. If all filters pass, the backend can be used and
* use_<backend> returns true. If any filter fails, then use_<backend> returns
* false.
*
* In order to aid in debugging, each filter takes sdp_params and a debug flag.
* If the debug flag is set, the filter will print a warning message if it fails.
* The behavior of select_sdp_backend is to return the first backend that
* succeeds. If no backend is viable then it will run each use_<backend> function
* with debug=true and return SDPBackend::error.
*/
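
/*
* For illustration, each can_use_<backend> gate below composes its filters as a
* fixed array of function pointers and short-circuits on the first failure,
* roughly like this minimal sketch (mirroring can_use_flash_attention below):
*
*   constexpr auto constraints = array_of<bool (*)(sdp_params const&, bool)>(
*       check_runtime_disabled_flash,
*       check_all_tensors_on_device,
*       check_tensor_shapes);
*   for (auto& constraint : constraints) {
*     if (!constraint(params, debug)) {
*       return false;
*     }
*   }
*   return true;
*/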

namespace sdp {
namespace {
// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
  constexpr std::array<SDPBackend, num_backends> default_order{
      SDPBackend::cudnn_attention,
      SDPBackend::flash_attention,
      SDPBackend::efficient_attention,
      SDPBackend::math};
  return default_order;
}

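// Whether the GEMMs in the fused kernels can use tensor cores: always on
// sm80+, and on sm70-sm75 only for half-precision (fp16/bf16) inputs.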
bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
  if (dprops->major >= 8) {
    return true;
  }
  if (dprops->major >= 7) {
    return is_half;
  }
  return false;
}
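// Minimum alignment (in elements) required of the head_dim for the
// mem-efficient attention GEMMs: at least 4 on sm80+, and 128 bits worth of
// elements (8 for fp16/bf16, 4 for fp32) whenever tensor cores are used.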
int64_t minimum_gemm_alignment(sdp_params const& params) {
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_half = (params.query.dtype() == at::kHalf) ||
      (params.query.dtype() == at::kBFloat16);
  bool use_tc = use_tensor_cores(params, dprops, is_half);
  int64_t matmul_alignment_mn = 1;
  if (dprops->major >= 8) {
    matmul_alignment_mn = 4;
  }
  int64_t bits_per_scalar = is_half ? 16 : 32;
  if (use_tc) {
    matmul_alignment_mn = std::max(matmul_alignment_mn, 128 / bits_per_scalar);
  }
  return matmul_alignment_mn;
}

bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
  // All head_dim sizes must be equal and less than or equal to 256
  const auto max_size = c10::SymInt(256);
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  bool same_head_dim_size =
      query_size_last == key_size_last && query_size_last == value_size_last;
  if (!(same_head_dim_size && (query_size_last <= max_size))) {
    if (debug) {
      TORCH_WARN(
          "Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
  const auto max_size = c10::SymInt(256);
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  bool same_head_dim_size =
      query_size_last == key_size_last && query_size_last == value_size_last;
  if (!(same_head_dim_size && (query_size_last % 8 == 0) &&
        (query_size_last <= max_size))) {
    if (debug) {
      TORCH_WARN(
          "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          params.key.sym_size(-1),
          ", Value.size(-1): ",
          params.value.sym_size(-1),
          " instead.");
    }
    return false;
  }
  return true;
}

bool check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  const int64_t alignment = minimum_gemm_alignment(params);
  if (!(query_size_last == params.key.sym_size(-1) &&
        query_size_last % alignment == 0 && query_size_last > 0 &&
        value_size_last % alignment == 0 && value_size_last > 0)) {
    if (debug) {
      TORCH_WARN(
          "Mem efficient attention requires last dimension of inputs to be divisible by ",
          alignment,
          ". ",
          "Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          params.key.sym_size(-1),
          ", Value.size(-1): ",
          params.value.sym_size(-1),
          " instead.");
    }
    return false;
  }
  return true;
}

template <int Major, int Minor>
struct SMVersion {
  static constexpr int major = Major;
  static constexpr int minor = Minor;
  constexpr SMVersion() = default;
};

/**
 * Checks if the given CUDA device architecture is inclusively within the specified range.
 *
 * @tparam lower_bound The lower bound of the CUDA device architecture range (an SMVersion).
 * @tparam upper_bound The upper bound of the CUDA device architecture range (an SMVersion).
 * @param dprops The device properties of the CUDA device to check.
 * @return True if the device architecture is within the specified range, false otherwise.
 */
template <typename lower_bound, typename upper_bound>
bool check_sm_version(cudaDeviceProp* dprops) {
  bool is_gte_lower_bound = dprops->major > lower_bound::major ||
      (dprops->major == lower_bound::major &&
       dprops->minor >= lower_bound::minor);
  bool is_lte_upper_bound = dprops->major < upper_bound::major ||
      (dprops->major == upper_bound::major &&
       dprops->minor <= upper_bound::minor);
  return is_gte_lower_bound && is_lte_upper_bound;
}

bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
  // Check that the gpu is capable of running flash attention
  using sm80 = SMVersion<8, 0>;
  using sm90 = SMVersion<9, 0>;
#if USE_ROCM
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
      auto dprops = at::cuda::getCurrentDeviceProperties();
      if (debug) {
          TORCH_WARN(
                  "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
      }
      return false;
  }
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (!check_sm_version<sm80, sm90>(dprops)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ",
          dprops->major,
          ".",
          dprops->minor,
          " gpu.");
    }
    return false;
  }
#endif
  return true;
}

bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
  // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
  using sm50 = SMVersion<5, 0>;
  using sm90 = SMVersion<9, 0>;
#if USE_ROCM
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
      auto dprops = at::cuda::getCurrentDeviceProperties();
      if (debug) {
          TORCH_WARN(
                  "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
      }
      return false;
  }
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (!check_sm_version<sm50, sm90>(dprops)) {
    if (debug) {
      TORCH_WARN(
          "Mem Efficient Attention only supports gpu architectures in the range [sm50, sm90]. Attempting to run on a sm ",
          dprops->major,
          ".",
          dprops->minor,
          " gpu.");
    }
    return false;
  }
#endif
  return true;
}

bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
    sdp_params const& params,
    bool debug) {
  // Flash Attention will raise an error in the backward pass if the head_dim
  // size is greater than 192 and the device is in the range [sm86, sm89]
  using sm86 = SMVersion<8, 6>;
  using sm89 = SMVersion<8, 9>;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
  bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
  bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224;
  bool is_dropout = params.dropout > 0.0;
  // head_dim size in (192, 224] is not supported on sm86 and sm89
  bool cond1 = is_head_dim_gt192 && is_head_dim_lte224;
  // head_dim size > 224 with dropout is not supported on sm86 and sm89
  bool cond2 = params.query.sym_size(-1) > 224 && is_dropout;
  if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or "
          "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range [sm86, sm89].",
          " Attempting to run with dropout set to: ", params.dropout,
          " and head_dim: ",
          params.query.sym_size(-1), " on a sm ", dprops->major, ".",
          dprops->minor, " gpu.");
    }
    return false;
  }
  return true;
}

bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) {
  // FlashAttention 2 updated the default mask meaning for causal in this PR:
  // 9e5e8bc91e it is now aligned to lower_right which would be a BC break
  // for non-square masks. We will not support non-square masks for causal w/ FAV2
  if (params.is_causal &&
      !params.query.is_nested() && !params.key.is_nested() &&
      params.query.sym_size(-2) != params.key.sym_size(-2)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k. ",
          "Got seqlen_q: ", params.query.sym_size(-2), " seqlen_k: ",
          params.key.sym_size(-2), ". If you would like to use causal attention with non-square masks, please see CausalAttnMask.");
    }
    return false;
  }
  return true;
}

bool check_all_tensors_on_device(sdp_params const& params, bool debug) {
  // Check that all tensors are on the GPU device
  // This should be handled by the stub dispatch, but when we call
  // can_use_*_attention directly from Python we need to ensure that the
  // tensors are on CUDA.
  if (params.query.device().type() != at::DeviceType::CUDA) {
    if (debug) {
      TORCH_WARN(
          "All tensors need to be on cuda device. Got query on device: ",
          params.query.device(),
          ", key on device: ",
          params.key.device(),
          ", value on device: ",
          params.value.device());
    }
    return false;
  }
  return true;
}

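// Shape/version gate for the cuDNN fused attention kernels: validates head_dim
// (multiple of 8, max 128 or 256 depending on cuDNN version), sequence
// lengths, dropout support, and the minimum cuDNN version.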
bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
  const auto s_q = params.query.sym_size(2);
  const auto s_k = params.key.sym_size(2);
  const auto head_dim = params.query.sym_size(3);
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  if (cudnn_version >= 90000) {
    if (head_dim % 8 != 0 || head_dim > 256) {
      if (debug) {
        TORCH_WARN("head_dim should be a multiple of 8 and no more than 256");
      }
      return false;
    }
  } else {
    if (head_dim % 8 != 0 || head_dim > 128) {
      if (debug) {
        TORCH_WARN("head_dim should be a multiple of 8 and no more than 128");
      }
      return false;
    }
  }
  if (cudnn_version < 8903) {
    if (debug) {
      TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
    }
    return false;
  }
  if (params.dropout != 0.0 && cudnn_version < 8906) {
    if (debug) {
      TORCH_WARN("Dropout is only supported on cudnn 8.9.6 and onwards.");
    }
    return false;
  }
  if (cudnn_version < 90000) {
    if (s_q < 64) {
      if (debug) {
        TORCH_WARN("s_q less than 64 is not supported before cudnn 9.0.0");
      }
      return false;
    }
    if ((s_q % 64 != 0 || s_k % 64 != 0) && params.dropout != 0.0) {
      if (debug) {
        TORCH_WARN(
            "s_q not a multiple of 64 with padding/dropout is not supported before cudnn 9.0.0");
      }
      return false;
    }
  }
  if (s_k % 64 != 0 && cudnn_version < 8906) {
    if (debug) {
      TORCH_WARN("not-multiple-of-64 seq_kv is not supported below 8.9.6");
    }
    return false;
  }
  return true;
}

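// Stride check for the cuDNN layouts: dense inputs must either match the
// "packed QKV" layout (sequence stride of 3*h*d) or the standard unpacked
// BSHD layout (sequence stride of h*d); anything else is rejected.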
bool check_cudnn_layout(sdp_params const& params, bool debug) {
  const int64_t h = params.query.size(1);
  const int64_t s_q = params.query.size(2);
  const int64_t d = params.query.size(3);
  const int64_t s_k = params.key.size(2);
  const int64_t s_v = params.value.size(2);
  // corresponds to cuDNN's "packed QKV" layout
  const bool packed_query_layout_ok = (params.query.stride(0) == s_q * 3 * h * d) &&
                                 (params.query.stride(1) == d) &&
                                 (params.query.stride(2) == 3 * h * d) &&
                                 (params.query.stride(3) == 1);
  const bool packed_key_layout_ok = (params.key.stride(0) == s_k * 3 * h * d) &&
                               (params.key.stride(1) == d) &&
                               (params.key.stride(2) == 3 * h * d) &&
                               (params.key.stride(3) == 1);
  const bool packed_value_layout_ok = (params.value.stride(0) == s_v * 3 * h * d) &&
                                 (params.value.stride(1) == d) &&
                                 (params.value.stride(2) == 3 * h * d) &&
                                 (params.value.stride(3) == 1);

  const bool packed_layout_ok = packed_query_layout_ok && packed_key_layout_ok && packed_value_layout_ok;

  const bool query_layout_ok = (params.query.stride(0) == s_q * h * d) &&
                               (params.query.stride(1) == d) &&
                               (params.query.stride(2) == h * d) &&
                               (params.query.stride(3) == 1);
  const bool key_layout_ok = (params.key.stride(0) == s_k * h * d) &&
                              (params.key.stride(1) == d) &&
                              (params.key.stride(2) == h * d) &&
                              (params.key.stride(3) == 1);
  const bool value_layout_ok = (params.value.stride(0) == s_v * h * d) &&
                               (params.value.stride(1) == d) &&
                               (params.value.stride(2) == h * d) &&
                               (params.value.stride(3) == 1);

  const bool layout_ok = query_layout_ok && key_layout_ok && value_layout_ok;

  if (!packed_layout_ok && !layout_ok) {
    if (debug) {
      if (!packed_layout_ok) {
        if (!packed_query_layout_ok) {
          TORCH_WARN("Query tensor was not in cuDNN-supported packed QKV layout", params.query.strides());
        }
        if (!packed_key_layout_ok) {
          TORCH_WARN("Key tensor was not in cuDNN-supported packed QKV layout", params.key.strides());
        }
        if (!packed_value_layout_ok) {
          TORCH_WARN("Value tensor was not in cuDNN-supported packed QKV layout", params.value.strides());
        }
      }
      if (!layout_ok) {
        if (!query_layout_ok) {
          TORCH_WARN("Query tensor was not in cuDNN-supported unpacked QKV layout", params.query.strides());
        }
        if (!key_layout_ok) {
          TORCH_WARN("Key tensor was not in cuDNN-supported unpacked QKV layout", params.key.strides());
        }
        if (!value_layout_ok) {
          TORCH_WARN("Value tensor was not in cuDNN-supported unpacked QKV layout", params.value.strides());
        }
      }
    }
    return false;
  }
  return true;
}

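// The cuDNN fused attention kernels require an sm80-sm90 GPU.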
bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
  using sm80 = SMVersion<8, 0>;
  using sm90 = SMVersion<9, 0>;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (!check_sm_version<sm80, sm90>(dprops)) {
    if (debug) {
      TORCH_WARN(
          "cuDNN MHA only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ",
          dprops->major,
          ".",
          dprops->minor,
          " gpu.");
    }
    return false;
  }
  return true;
}

bool check_is_causal(sdp_params const& params, bool debug) {
  // Check that the input is causal
  if (!params.is_causal) {
    if (debug) {
      TORCH_WARN("CuDNN requires is_causal=True.");
    }
    return false;
  }
  return true;
}

bool check_for_nested_inputs(sdp_params const& params, bool debug) {
  // Reject nested inputs; cuDNN currently does not support them
  if (has_for_nested_inputs(params)) {
    if (debug) {
      TORCH_WARN("CuDNN currently does not support nested inputs.");
    }
    return false;
  }
  return true;
}

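// The fused kernels only run in low precision: fp16/bf16 on sm80+, and fp16
// only on earlier architectures.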
bool check_dtypes_low_precision(sdp_params const& params, bool debug) {
  auto dprop = at::cuda::getCurrentDeviceProperties();
  if (dprop->major >= 8) {
    constexpr auto sm80_dtypes =
        array_of<at::ScalarType>(at::kHalf, at::kBFloat16);
    return check_tensor_dtype(params, sm80_dtypes, debug);
  } else {
    constexpr auto default_dtypes = array_of<at::ScalarType>(at::kHalf);
    return check_tensor_dtype(params, default_dtypes, debug);
  }
}

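// The cuDNN SDPA backend is opt-in: it is only considered when the
// TORCH_CUDNN_SDPA_ENABLED environment variable is set (checked once per
// process).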
bool check_runtime_enabled_cudnn(sdp_params const& params, bool debug) {
  static c10::once_flag supported_flag;
  static bool supported = false;
  c10::call_once(supported_flag, []() {
    supported = (c10::utils::check_env("TORCH_CUDNN_SDPA_ENABLED") == true);
  });
  if (!supported) {
    if (debug) {
      TORCH_WARN(
          "The CuDNN backend needs to be enabled by setting the environment variable `TORCH_CUDNN_SDPA_ENABLED=1`");
    }
    return false;
  }
  return true;
}

bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
  // We check the global context to see if the user has explicitly turned off
  // the cudnn sdp kernels
  if (!at::globalContext().userEnabledCuDNNSDP()) {
    if (debug) {
      TORCH_WARN("CuDNN attention has been runtime disabled.");
    }
    return false;
  }
  return true;
}

bool check_cudnn_requires_grad(sdp_params const& params, bool debug) {
  // Check that the inputs do not require grad
  if (input_requires_grad(params)) {
    if (debug) {
      TORCH_WARN("CuDNN does not currently support inputs with requires_grad=True.");
    }
    return false;
  }
  return true;
}

} // namespace

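// Public entry points. Each can_use_<backend> gate runs its list of filters
// and returns true only if every check passes; see Note [SDPA Runtime
// Dispatch] above.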
bool can_use_cudnn_attention(const sdp_params& params, bool debug) {

  // Define gate functions that determine if a cuDNN kernel can be run
  // Replace with std::to_array when we migrate to c++20
  constexpr auto general_constraints =
      array_of<bool (*)(sdp_params const&, bool)>(
          check_runtime_enabled_cudnn,
          check_runtime_disabled_cudnn,
          check_cudnn_hardware_support,
          check_all_tensors_on_device,
          check_cudnn_tensor_shapes,
          check_cudnn_layout,
          // check_is_causal,
          check_for_nested_inputs,
          check_cudnn_requires_grad,
          check_dtypes_low_precision);
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }
  return true;
}

bool can_use_flash_attention(sdp_params const& params, bool debug) {
#ifndef USE_FLASH_ATTENTION
  TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
  return false;
#endif

  // Define gate functions that determine if a flash kernel can be run
  // Replace with std::to_array when we migrate to c++20
  constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_flash,
      check_all_tensors_on_device,
      check_tensor_shapes,
      check_for_attn_mask,
      check_head_dim_size_flash,
      check_flash_attention_hardware_support,
      check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
      check_flash_causal_non_square_seqlens,
      check_dtypes_low_precision);
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  if (has_for_nested_inputs(params)) {
    constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_nested,
        check_head_dim_size_flash_nested,
        check_for_seq_len_0_nested_tensor);
    for (auto& constraint : nested_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  return true;
}

bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
#ifndef USE_MEM_EFF_ATTENTION
  TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
  return false;
#endif
  // Constraints specific to mem efficient attention
  constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
      array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
  constexpr auto less_than_sm80_mem_efficient_dtypes =
      array_of<at::ScalarType>(at::kHalf, at::kFloat);
#ifdef USE_ROCM
  constexpr auto aotriton_mem_efficient_dtypes =
      array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
#endif

  // Define gate functions that determine if a mem efficient kernel can be run
  constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_mem_efficient,
      check_all_tensors_on_device,
      check_mem_efficient_hardware_support,
      check_tensor_shapes,
#ifdef USE_ROCM
      check_head_dim_size_flash
#else
      check_head_dim_size_mem_efficient
#endif
  );
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  if (has_for_nested_inputs(params)) {
#ifdef USE_ROCM
    TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention.");
    return false;
#endif
    constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_requires_grad_and_nested,
        check_batch_size_nested,
        check_for_seq_len_0_nested_tensor);
    for (auto& constraint : nested_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }

#ifdef USE_ROCM
  return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
#else
  auto dprop = at::cuda::getCurrentDeviceProperties();
  if (dprop->major >= 8) {
    return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
  }
#endif
  return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
}

SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
  // This function dispatches to the sdp backends in the order returned by
  // priority_order():
  // 1. cuDNN Attention
  // 2. Flash Attention
  // 3. Mem Efficient Attention
  // 4. Math fallback
  auto& ctx = at::globalContext();
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() &&
      !ctx.userEnabledMemEfficientSDP() && !ctx.userEnabledCuDNNSDP()) {
    return SDPBackend::error;
  }
  // Get ideal kernel ordering
  const auto ordering = priority_order(kernel_params);

  // Because TORCH_CHECK checks if the condition is true, we negate debug so
  // that the statements will be printed when debug is true
  bool print_debug = false;
  for (auto& backend : ordering) {
    switch (backend) {
      case SDPBackend::cudnn_attention:
        if (sdp::can_use_cudnn_attention(kernel_params, print_debug)) {
          return SDPBackend::cudnn_attention;
        }
        break;
      case SDPBackend::flash_attention:
        if (sdp::can_use_flash_attention(kernel_params, print_debug)) {
          return SDPBackend::flash_attention;
        }
        break;
      case SDPBackend::efficient_attention:
        if (sdp::can_use_mem_efficient_attention(kernel_params, print_debug)) {
          return SDPBackend::efficient_attention;
        }
        break;
      case SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. use_flash_attention or use_mem_efficient_attention did not satisfy the
  //    constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected

  print_debug = true;
  TORCH_WARN("Memory efficient kernel not used because:");
  sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
  TORCH_WARN("Flash attention kernel not used because:");
  sdp::can_use_flash_attention(kernel_params, print_debug);
  TORCH_WARN("CuDNN attention kernel not used because:");
  sdp::can_use_cudnn_attention(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
  return SDPBackend::error;
}

bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug) {
  // When this function is called we are assured that the nt is dim==4
  if (!params.query.is_nested()) {
    return true;
  }

  const auto nt_q_tensor_impl =
      at::native::get_nested_tensor_impl(params.query);
  const at::Tensor& sizes = nt_q_tensor_impl->get_nested_sizes();
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const int64_t n_tensors = params.query.size(0);
  const int64_t size_tensor_stride = sizes.stride(0);

  // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
  for (const auto i : c10::irange(n_tensors)) {
    if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) {
      if (debug) {
        TORCH_WARN(
            "Packed projection for fused kernels does not support sequence_length <= 1");
      }
      return false;
    }
  }

  return true;
}

} // namespace sdp