#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/util/CallOnce.h>

#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>

#if USE_ROCM
#include <aotriton/flash.h>
#endif

/**
 * Note [SDPA Runtime Dispatch]
 * SDPA relies on a runtime dispatch mechanism to select the appropriate
 * kernel. This file exposes that mechanism through `select_sdp_backend`.
 * The basic structure of this function is to call `priority_order` to get a
 * list of backends to try, and then iterate through them until one succeeds.
 * Each backend defines a use_<backend> function that returns true if the
 * backend can be run with the given SDP parameters. The use_<backend> function
 * iterates over a list of "filters" that check for specific properties of
 * the SDP parameters. If all filters pass, the backend can be used and
 * use_<backend> returns true. If any filter fails, use_<backend> returns false.
 *
 * To aid in debugging, each filter takes sdp_params and a debug flag.
 * If the debug flag is set, the filter will print a warning message when it fails.
 * select_sdp_backend returns the first backend that succeeds. If no backend is
 * viable, it re-runs each use_<backend> function with debug=true and returns
 * SDPBackend::error.
 */
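// As a rough illustration (not the exact call sites), the CUDA implementation
// of scaled_dot_product_attention is expected to drive this along the lines of:
//
//   sdp::sdp_params params{query, key, value, attn_mask, dropout_p, is_causal};
//   switch (sdp::select_sdp_backend(params)) {
//     case sdp::SDPBackend::cudnn_attention:     // launch the cuDNN kernel
//     case sdp::SDPBackend::flash_attention:     // launch the flash kernel
//     case sdp::SDPBackend::efficient_attention: // launch the mem-efficient kernel
//     case sdp::SDPBackend::math:                // fall back to the composite math path
//     default:                                   // SDPBackend::error: no viable kernel
//   }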

namespace sdp {
namespace {
// flash_attention V2 is universally faster than efficient_attention and Math
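// Note that the order returned here is static: select_sdp_backend() walks it
// and picks the first backend whose can_use_*_attention checks pass and which
// the user has not disabled via the global context flags.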
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
  constexpr std::array<SDPBackend, num_backends> default_order{
      SDPBackend::cudnn_attention,
      SDPBackend::flash_attention,
      SDPBackend::efficient_attention,
      SDPBackend::math};
  return default_order;
}

bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
  if (dprops->major >= 8) {
    return true;
  }
  if (dprops->major >= 7) {
    return is_half;
  }
  return false;
}
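// Returns the alignment (in elements) required of the head dimension for the
// mem-efficient GEMMs. For example, fp16/bf16 on a tensor-core capable GPU
// needs 128 / 16 = 8-element alignment, while fp32 on sm80+ needs
// max(4, 128 / 32) = 4; cases that cannot use tensor cores fall back to 1.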
int64_t minimum_gemm_alignment(sdp_params const& params) {
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_half = (params.query.dtype() == at::kHalf) ||
      (params.query.dtype() == at::kBFloat16);
  bool use_tc = use_tensor_cores(params, dprops, is_half);
  int64_t matmul_alignment_mn = 1;
  if (dprops->major >= 8) {
    matmul_alignment_mn = 4;
  }
  int64_t bits_per_scalar = is_half ? 16 : 32;
  if (use_tc) {
    matmul_alignment_mn = std::max(matmul_alignment_mn, 128 / bits_per_scalar);
  }
  return matmul_alignment_mn;
}

bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
  // All head_dim sizes must be equal and at most 256
  const auto max_size = c10::SymInt(256);
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  bool same_head_dim_size =
      query_size_last == key_size_last && query_size_last == value_size_last;
  if (!(same_head_dim_size && (query_size_last <= max_size))) {
    if (debug) {
      TORCH_WARN(
          "Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
  const auto max_size = c10::SymInt(256);
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  bool same_head_dim_size =
      query_size_last == key_size_last && query_size_last == value_size_last;
  if (!(same_head_dim_size && (query_size_last % 8 == 0) &&
        (query_size_last <= max_size))) {
    if (debug) {
      TORCH_WARN(
          "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          params.key.sym_size(-1),
          ", Value.size(-1): ",
          params.value.sym_size(-1),
          " instead.");
    }
    return false;
  }
  return true;
}

bool check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  const int64_t alignment = minimum_gemm_alignment(params);
  if (!(query_size_last == params.key.sym_size(-1) &&
        query_size_last % alignment == 0 && query_size_last > 0 &&
        value_size_last % alignment == 0 && value_size_last > 0)) {
    if (debug) {
      TORCH_WARN(
          "Mem efficient attention requires last dimension of inputs to be divisible by ",
          alignment,
          ". ",
          "Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          params.key.sym_size(-1),
          ", Value.size(-1): ",
          params.value.sym_size(-1),
          " instead.");
    }
    return false;
  }
  return true;
}

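// Tag type used to express inclusive architecture bounds, e.g.
// `using sm80 = SMVersion<8, 0>;`, consumed by check_sm_version below.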
template <int Major, int Minor>
struct SMVersion {
  static constexpr int major = Major;
  static constexpr int minor = Minor;
  constexpr SMVersion() = default;
};

/**
 * Checks if the current CUDA device architecture is inclusively within the specified range.
 *
 * @tparam lower_bound The lower bound of the CUDA device architecture range.
 * @tparam upper_bound The upper bound of the CUDA device architecture range.
 * @param dprops The device properties of the current CUDA device.
 * @return True if the current CUDA device architecture is within the specified range, false otherwise.
 */
template <typename lower_bound, typename upper_bound>
bool check_sm_version(cudaDeviceProp* dprops) {
  bool is_gte_lower_bound = dprops->major > lower_bound::major ||
      (dprops->major == lower_bound::major &&
       dprops->minor >= lower_bound::minor);
  bool is_lte_upper_bound = dprops->major < upper_bound::major ||
      (dprops->major == upper_bound::major &&
       dprops->minor <= upper_bound::minor);
  return is_gte_lower_bound && is_lte_upper_bound;
}

bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
  // Check that the gpu is capable of running flash attention
  using sm80 = SMVersion<8, 0>;
  using sm90 = SMVersion<9, 0>;
#if USE_ROCM
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    if (debug) {
      TORCH_WARN(
          "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
    }
    return false;
  }
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (!check_sm_version<sm80, sm90>(dprops)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ",
          dprops->major,
          ".",
          dprops->minor,
          " gpu.");
    }
    return false;
  }
#endif
  return true;
}

bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
  // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
  using sm50 = SMVersion<5, 0>;
  using sm90 = SMVersion<9, 0>;
#if USE_ROCM
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    if (debug) {
      TORCH_WARN(
          "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
    }
    return false;
  }
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (!check_sm_version<sm50, sm90>(dprops)) {
    if (debug) {
      TORCH_WARN(
          "Mem Efficient Attention only supports gpu architectures in the range [sm50, sm90]. Attempting to run on a sm ",
          dprops->major,
          ".",
          dprops->minor,
          " gpu.");
    }
    return false;
  }
#endif
  return true;
}

bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
    sdp_params const& params,
    bool debug) {
  // Flash Attention will raise an error in the backward pass if the head_dim
  // size is greater than 192 and the device is in the range [sm86, sm89]
  using sm86 = SMVersion<8, 6>;
  using sm89 = SMVersion<8, 9>;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
  bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
  bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224;
  bool is_dropout = params.dropout > 0.0;
  // head_dim size in (192, 224] is not supported on sm86 and sm89
  bool cond1 = is_head_dim_gt192 && is_head_dim_lte224;
  // head_dim size > 224 with dropout is not supported on sm86 and sm89
  bool cond2 = params.query.sym_size(-1) > 224 && is_dropout;
  if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or "
          "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range [sm86, sm89].",
          " Attempting to run with dropout set to: ", params.dropout,
          " and head_dim: ",
          params.query.sym_size(-1), " on a sm ", dprops->major, ".",
          dprops->minor, " gpu.");
    }
    return false;
  }
  return true;
}

bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) {
  // FlashAttention 2 updated the default mask meaning for causal in this PR:
  // 9e5e8bc91e. It is now aligned to lower_right, which would be a BC break
  // for non-square masks. We will not support non-square masks for causal w/ FAV2.
  if (params.is_causal &&
      !params.query.is_nested() && !params.key.is_nested() &&
      params.query.sym_size(-2) != params.key.sym_size(-2)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k. ",
          "Got seqlen_q: ", params.query.sym_size(-2), " seqlen_k: ",
          params.key.sym_size(-2), ". If you would like to use causal attention with non-square masks, please see CausalAttnMask.");
    }
    return false;
  }
  return true;
}

bool check_all_tensors_on_device(sdp_params const& params, bool debug) {
  // Check that all tensors are on the GPU device
  // This should be handled by the stub dispatch, but when we call
  // can_use_*_attention directly from Python we need to ensure that the
  // tensors are on CUDA.
  if (params.query.device().type() != at::DeviceType::CUDA) {
    if (debug) {
      TORCH_WARN(
          "All tensors need to be on cuda device. Got query on device: ",
          params.query.device(),
          ", key on device: ",
          params.key.device(),
          ", value on device: ",
          params.value.device());
    }
    return false;
  }
  return true;
}

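// Shape/version gates for the cuDNN backend. In summary: cuDNN 8.9.3 is the
// minimum supported version, head_dim must be a multiple of 8 and at most 128
// (pre-9.0) or 256 (9.0+), dropout requires 8.9.6+, and short or
// non-multiple-of-64 sequence lengths carry extra restrictions on versions
// before 9.0, as detailed in the individual checks below.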
bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
  const auto s_q = params.query.sym_size(2);
  const auto s_k = params.key.sym_size(2);
  const auto head_dim = params.query.sym_size(3);
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  if (cudnn_version >= 90000) {
    if (head_dim % 8 != 0 || head_dim > 256) {
      if (debug) {
        TORCH_WARN("head_dim should be a multiple of 8 and no more than 256");
      }
      return false;
    }
  } else {
    if (head_dim % 8 != 0 || head_dim > 128) {
      if (debug) {
        TORCH_WARN("head_dim should be a multiple of 8 and no more than 128");
      }
      return false;
    }
  }
  if (cudnn_version < 8903) {
    if (debug) {
      TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
    }
    return false;
  }
  if (params.dropout != 0.0 && cudnn_version < 8906) {
    if (debug) {
      TORCH_WARN("Dropout reference is only supported on 8.9.6 onwards.");
    }
    return false;
  }
  if (cudnn_version < 90000) {
    if (s_q < 64) {
      if (debug) {
        TORCH_WARN("s_q less than 64 is not supported before cudnn 9.0.0");
      }
      return false;
    }
    if ((s_q % 64 != 0 || s_k % 64 != 0) && params.dropout != 0.0) {
      if (debug) {
        TORCH_WARN(
355 "s_q not a multiple of 64 with padding/dropout is not supported with cudnn version 9.0.0");
      }
      return false;
    }
  }
  if (s_k % 64 != 0 && cudnn_version < 8906) {
    if (debug) {
      TORCH_WARN("not-multiple-of-64 seq_kv is not supported below 8.9.6");
    }
    return false;
  }
  return true;
}

bool check_cudnn_layout(sdp_params const& params, bool debug) {
  const int64_t h = params.query.size(1);
  const int64_t s_q = params.query.size(2);
  const int64_t d = params.query.size(3);
  const int64_t s_k = params.key.size(2);
  const int64_t s_v = params.value.size(2);
  // corresponds to cuDNN's "packed QKV" layout
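  // i.e. q, k and v are [B, H, S, D] views into a single packed [B, S, 3, H, D]
  // buffer, so each has strides (S * 3 * H * D, D, 3 * H * D, 1)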
  const bool packed_query_layout_ok = (params.query.stride(0) == s_q * 3 * h * d) &&
      (params.query.stride(1) == d) &&
      (params.query.stride(2) == 3 * h * d) &&
      (params.query.stride(3) == 1);
  const bool packed_key_layout_ok = (params.key.stride(0) == s_k * 3 * h * d) &&
      (params.key.stride(1) == d) &&
      (params.key.stride(2) == 3 * h * d) &&
      (params.key.stride(3) == 1);
  const bool packed_value_layout_ok = (params.value.stride(0) == s_v * 3 * h * d) &&
      (params.value.stride(1) == d) &&
      (params.value.stride(2) == 3 * h * d) &&
      (params.value.stride(3) == 1);

  const bool packed_layout_ok = packed_query_layout_ok && packed_key_layout_ok && packed_value_layout_ok;

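  // corresponds to cuDNN's "unpacked" layout: q, k and v are each [B, H, S, D]
  // views of their own contiguous [B, S, H, D] tensor, i.e. strides
  // (S * H * D, D, H * D, 1)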
  const bool query_layout_ok = (params.query.stride(0) == s_q * h * d) &&
      (params.query.stride(1) == d) &&
      (params.query.stride(2) == h * d) &&
      (params.query.stride(3) == 1);
  const bool key_layout_ok = (params.key.stride(0) == s_k * h * d) &&
      (params.key.stride(1) == d) &&
      (params.key.stride(2) == h * d) &&
      (params.key.stride(3) == 1);
  const bool value_layout_ok = (params.value.stride(0) == s_v * h * d) &&
      (params.value.stride(1) == d) &&
      (params.value.stride(2) == h * d) &&
      (params.value.stride(3) == 1);

  const bool layout_ok = query_layout_ok && key_layout_ok && value_layout_ok;

  // Accept either the fully packed QKV layout or the fully unpacked layout.
  if (!packed_layout_ok && !layout_ok) {
    if (debug) {
      if (!packed_layout_ok) {
        if (!packed_query_layout_ok) {
          TORCH_WARN("Query tensor was not in cuDNN-supported packed QKV layout: ", params.query.strides());
        }
        if (!packed_key_layout_ok) {
          TORCH_WARN("Key tensor was not in cuDNN-supported packed QKV layout: ", params.key.strides());
        }
        if (!packed_value_layout_ok) {
          TORCH_WARN("Value tensor was not in cuDNN-supported packed QKV layout: ", params.value.strides());
        }
      }
      if (!layout_ok) {
        if (!query_layout_ok) {
          TORCH_WARN("Query tensor was not in cuDNN-supported unpacked QKV layout: ", params.query.strides());
        }
        if (!key_layout_ok) {
          TORCH_WARN("Key tensor was not in cuDNN-supported unpacked QKV layout: ", params.key.strides());
        }
        if (!value_layout_ok) {
          TORCH_WARN("Value tensor was not in cuDNN-supported unpacked QKV layout: ", params.value.strides());
        }
      }
    }
    return false;
  }
  return true;
}

bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
  using sm80 = SMVersion<8, 0>;
  using sm90 = SMVersion<9, 0>;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  if (!check_sm_version<sm80, sm90>(dprops)) {
    if (debug) {
      TORCH_WARN(
          "cuDNN MHA only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ",
          dprops->major,
          ".",
          dprops->minor,
          " gpu.");
    }
    return false;
  }
  return true;
}

bool check_is_causal(sdp_params const& params, bool debug) {
  // Check that the input is causal
  if (!params.is_causal) {
    if (debug) {
      TORCH_WARN("CuDNN requires is_causal=True.");
    }
    return false;
  }
  return true;
}

bool check_for_nested_inputs(sdp_params const& params, bool debug) {
  // Check that the inputs are not nested
  if (has_for_nested_inputs(params)) {
    if (debug) {
      TORCH_WARN("CuDNN currently does not support nested inputs.");
    }
    return false;
  }
  return true;
}

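// sm80+ supports both fp16 and bf16 for the fused kernels; earlier
// architectures are restricted to fp16.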
bool check_dtypes_low_precision(sdp_params const& params, bool debug) {
  auto dprop = at::cuda::getCurrentDeviceProperties();
  if (dprop->major >= 8) {
    constexpr auto sm80_dtypes =
        array_of<at::ScalarType>(at::kHalf, at::kBFloat16);
    return check_tensor_dtype(params, sm80_dtypes, debug);
  } else {
    constexpr auto default_dtypes = array_of<at::ScalarType>(at::kHalf);
    return check_tensor_dtype(params, default_dtypes, debug);
  }
}

bool check_runtime_enabled_cudnn(sdp_params const& params, bool debug) {
  static c10::once_flag supported_flag;
  static bool supported = false;
  c10::call_once(supported_flag, []() {
    supported = (c10::utils::check_env("TORCH_CUDNN_SDPA_ENABLED") == true);
  });
  if (!supported) {
    if (debug) {
      TORCH_WARN(
          "The CuDNN backend needs to be enabled by setting the environment variable `TORCH_CUDNN_SDPA_ENABLED=1`");
    }
    return false;
  }
  return true;
}

bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
  // We check the global context to see if the user has explicitly turned off
  // the cudnn sdp kernels
  if (!at::globalContext().userEnabledCuDNNSDP()) {
    if (debug) {
      TORCH_WARN("CuDNN attention has been runtime disabled.");
    }
    return false;
  }
  return true;
}

bool check_cudnn_requires_grad(sdp_params const& params, bool debug) {
  // Check that the inputs do not require grad
  if (input_requires_grad(params)) {
    if (debug) {
      TORCH_WARN("CuDNN does not currently support inputs with requires_grad=True.");
    }
    return false;
  }
  return true;
}

} // namespace

bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
  // Define gate functions that determine if a cuDNN kernel can be run
  // Replace with std::to_array when we migrate to c++20
  constexpr auto general_constraints =
      array_of<bool (*)(sdp_params const&, bool)>(
          check_runtime_enabled_cudnn,
          check_runtime_disabled_cudnn,
          check_cudnn_hardware_support,
          check_all_tensors_on_device,
          check_cudnn_tensor_shapes,
          check_cudnn_layout,
          // check_is_causal,
          check_for_nested_inputs,
          check_cudnn_requires_grad,
          check_dtypes_low_precision);
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }
  return true;
}

bool can_use_flash_attention(sdp_params const& params, bool debug) {
#ifndef USE_FLASH_ATTENTION
  TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
  return false;
#endif

  // Define gate functions that determine if a flash kernel can be run
  // Replace with std::to_array when we migrate to c++20
  constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_flash,
      check_all_tensors_on_device,
      check_tensor_shapes,
      check_for_attn_mask,
      check_head_dim_size_flash,
      check_flash_attention_hardware_support,
      check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
      check_flash_causal_non_square_seqlens,
      check_dtypes_low_precision);
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  if (has_for_nested_inputs(params)) {
    constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_nested,
        check_head_dim_size_flash_nested,
        check_for_seq_len_0_nested_tensor);
    for (auto& constraint : nested_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense</*ignore_singleton_dim=*/true>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  return true;
}

bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
#ifndef USE_MEM_EFF_ATTENTION
  TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
  return false;
#endif
  // Constraints specific to mem efficient attention
  constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
      array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
  constexpr auto less_than_sm80_mem_efficient_dtypes =
      array_of<at::ScalarType>(at::kHalf, at::kFloat);
#ifdef USE_ROCM
  constexpr auto aotriton_mem_efficient_dtypes =
      array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
#endif

  // Define gate functions that determine if a mem efficient kernel can be run
  constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_mem_efficient,
      check_all_tensors_on_device,
      check_mem_efficient_hardware_support,
      check_tensor_shapes,
#ifdef USE_ROCM
      check_head_dim_size_flash
#else
      check_head_dim_size_mem_efficient
#endif
  );
  for (auto& constraint : general_constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  if (has_for_nested_inputs(params)) {
#ifdef USE_ROCM
    TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention.");
    return false;
#endif
    constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_requires_grad_and_nested,
        check_batch_size_nested,
        check_for_seq_len_0_nested_tensor);
    for (auto& constraint : nested_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }
  if (has_only_dense_inputs(params)) {
    constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
        check_batch_size_and_num_heads_dense,
        check_nonzero_sequence_lengths_dense,
        check_last_dim_stride_equals_1_dense</*ignore_singleton_dim=*/false>);
    for (auto& constraint : dense_constraints) {
      if (!constraint(params, debug)) {
        return false;
      }
    }
  }

#ifdef USE_ROCM
  return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
#else
  auto dprop = at::cuda::getCurrentDeviceProperties();
  if (dprop->major >= 8) {
    return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
  }
#endif
  return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
}

SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
  // This function defines the priority order of the different sdp backends:
  // 1. cuDNN attention
  // 2. Flash Attention
  // 3. Mem Efficient Attention
  // 4. Math fallback
  auto& ctx = at::globalContext();
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() &&
      !ctx.userEnabledMemEfficientSDP() && !ctx.userEnabledCuDNNSDP()) {
    return SDPBackend::error;
  }
  // Get ideal kernel ordering
  const auto ordering = priority_order(kernel_params);

  // Because TORCH_CHECK throws when its condition is false, we pass
  // !print_debug to it below so that execution aborts after the debug
  // warnings have been printed.
  bool print_debug = false;
  for (auto& backend : ordering) {
    switch (backend) {
      case SDPBackend::cudnn_attention:
        if (sdp::can_use_cudnn_attention(kernel_params, print_debug)) {
          return SDPBackend::cudnn_attention;
        }
        break;
      case SDPBackend::flash_attention:
        if (sdp::can_use_flash_attention(kernel_params, print_debug)) {
          return SDPBackend::flash_attention;
        }
        break;
      case SDPBackend::efficient_attention:
        if (sdp::can_use_mem_efficient_attention(kernel_params, print_debug)) {
          return SDPBackend::efficient_attention;
        }
        break;
      case SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. None of the fused backends (flash, mem efficient, cuDNN) satisfied the
  //    constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why each kernel was not selected.

  print_debug = true;
  TORCH_WARN("Memory efficient kernel not used because:");
  sdp::can_use_mem_efficient_attention(kernel_params, print_debug);
  TORCH_WARN("Flash attention kernel not used because:");
  sdp::can_use_flash_attention(kernel_params, print_debug);
  TORCH_WARN("CuDNN attention kernel not used because:");
  sdp::can_use_cudnn_attention(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.");
  return SDPBackend::error;
}

bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug) {
  // When this function is called we are assured that the NestedTensor is dim==4
  if (!params.query.is_nested()) {
    return true;
  }

  const auto nt_q_tensor_impl =
      at::native::get_nested_tensor_impl(params.query);
  const at::Tensor& sizes = nt_q_tensor_impl->get_nested_sizes();
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const int64_t n_tensors = params.query.size(0);
  const int64_t size_tensor_stride = sizes.stride(0);

  // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
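  // Each row of `sizes` holds (num_heads, seq_len, head_dim) for one batch
  // element, so offset + 1 selects the seq_len entry.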
  for (const auto i : c10::irange(n_tensors)) {
    if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) {
      if (debug) {
        TORCH_WARN(
            "Packed projection for fused kernels does not support sequence_length <= 1");
      }
      return false;
    }
  }

  return true;
}

} // namespace sdp