#pragma once
#include <ATen/Context.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/ScalarType.h>

#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <c10/core/SymInt.h>
#include <c10/core/SymFloat.h>
#include <c10/util/string_view.h>
#include <c10/util/Array.h>
#include <cmath>
#include <cstdint>
#include <functional>
#include <sstream>

namespace sdp {

constexpr int32_t num_backends = 5;
enum class SDPBackend {
  error = -1,
  math = 0,
  flash_attention = 1,
  efficient_attention = 2,
  cudnn_attention = 3,
  overrideable = 4
};

// Note that if this is changed, make sure to update the templated enum in
// mem_eff/kernel_forward.h and mem_eff/kernel_backward.h
enum class CustomMaskType {
  NoCustomMask = 0,
  CausalFromTopLeft = 1,
  CausalFromBottomRight = 2,
  NumCustomMaskTypes,
};

struct sdp_params {
  at::Tensor query;
  at::Tensor key;
  at::Tensor value;
  std::optional<at::Tensor> attn_mask;
  double dropout;
  bool is_causal;
  bool enable_gqa;
};
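
// Example (illustrative sketch only, not part of this header's API): building
// sdp_params for a dense [batch, num_heads, seq_len, head_dim] call with no
// mask, no dropout, and causal masking disabled. The shapes are arbitrary.
//
//   at::Tensor q = at::randn({2, 8, 128, 64});
//   at::Tensor k = at::randn({2, 8, 128, 64});
//   at::Tensor v = at::randn({2, 8, 128, 64});
//   sdp::sdp_params params{q, k, v, /*attn_mask=*/std::nullopt,
//                          /*dropout=*/0.0, /*is_causal=*/false,
//                          /*enable_gqa=*/false};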

SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
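
// Usage sketch (hypothetical caller, for illustration only): the selected
// backend is typically switched on to pick a kernel, falling back to the
// unfused math path when no fused kernel applies.
//
//   switch (sdp::select_sdp_backend_cpp(params)) {
//     case sdp::SDPBackend::flash_attention:
//       /* dispatch to the fused flash attention kernel */ break;
//     case sdp::SDPBackend::math:
//     default:
//       /* dispatch to the unfused math implementation */ break;
//   }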

inline c10::SymFloat calculate_scale(
    const at::Tensor& query,
    std::optional<double> scale) {
  const auto softmax_scale = scale.has_value()
      ? scale.value()
      : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
  return c10::SymFloat(softmax_scale);
}
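
// Note: when `scale` is not provided, the default softmax scaling factor is
// 1 / sqrt(head_dim), where head_dim is the last dimension of `query`.
// For example (illustrative only), head_dim == 64 gives a default scale of
// 1 / sqrt(64) == 0.125.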

using c10::array_of;

inline bool input_requires_grad(sdp_params const& params) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  return any_inputs_require_grad && gradmode_enabled;
}

inline bool has_for_nested_inputs(sdp_params const& params) {
  return
      (params.query.is_nested() && params.query.layout() == c10::kStrided) ||
      (params.key.is_nested() && params.key.layout() == c10::kStrided) ||
      (params.value.is_nested() && params.value.layout() == c10::kStrided);
}

inline bool has_for_dense_inputs(sdp_params const& params) {
  return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
}

inline bool has_only_dense_inputs(sdp_params const& params) {
  return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
}

template <typename dtype_vector>
inline bool check_tensor_dtype(
    sdp_params const& params,
    dtype_vector allowed_dtypes,
    bool debug) {
  auto query_dtype = params.query.dtype();
  if (!(query_dtype == params.key.dtype() &&
        query_dtype == params.value.dtype() &&
        (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
         allowed_dtypes.end()))) {
    if (debug) {
      TORCH_WARN(
          "Expected query, key and value to all be of dtype: {",
          c10::Join(", ", allowed_dtypes),
          "}. Got ",
          "Query dtype: ",
          params.query.dtype(),
          ", Key dtype: ",
          params.key.dtype(),
          ", and Value dtype: ",
          params.value.dtype(),
          " instead.");
    }
    return false;
  }
  return true;
}
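
// Usage sketch (illustrative; the allowed dtype list below is an example, not
// the authoritative set used by any particular backend selector):
//
//   constexpr auto example_allowed_dtypes =
//       array_of<at::ScalarType>(at::kFloat, at::kDouble, at::kBFloat16);
//   bool dtype_ok =
//       check_tensor_dtype(params, example_allowed_dtypes, /*debug=*/true);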

inline bool try_broadcast_param_size(
    const c10::SymInt q_size,
    const c10::SymInt k_size,
    const c10::SymInt v_size,
    c10::string_view param_name,
    bool debug) {
  auto max_size = std::max({q_size, k_size, v_size});
  if ((q_size != max_size && q_size != 1) ||
      (k_size != max_size && k_size != 1) ||
      (v_size != max_size && v_size != 1)) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels require query, key and value to have broadcastable ",
          param_name,
          "but got Query ",
          param_name,
          q_size,
          ", Key ",
          param_name,
          k_size,
          ", Value ",
          param_name,
          v_size,
          " instead.");
    }
    return false;
  }
  return true;
}
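
// Example of the broadcasting rule above (illustrative): sizes (1, 8, 8) are
// accepted because every entry is either 1 or the max (8), whereas (2, 8, 8)
// is rejected because 2 is neither 1 nor the max.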

inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
    at::Tensor const& param,
    c10::string_view param_name,
    bool debug) {
  const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
  const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
  auto num_head_dims = nt_tensor_impl->opt_size(1);
  if (!num_head_dims.has_value()) {
    // num_head_dims is ragged
    if (debug) {
      TORCH_WARN(
          "Fused kernels do not support ragged num_head_dims, ",
          param_name,
          "has a ragged num_heads.");
    }
    return false;
  }

  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const int64_t n_tensors = param.size(0);
  const int64_t size_tensor_stride = sizes.stride(0);

  // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
  for (const auto i : c10::irange(n_tensors)) {
    if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
      if (debug) {
        TORCH_WARN(
            "Fused kernels do not support seq_len == 0, ",
            param_name,
            "has a seq len of 0.");
      }
      return false;
    }
  }
  return true;
}

inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
  // When this function is called we are assured that the nt is dim==4
  bool q_is_safe = params.query.is_nested()
      ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
            params.query, "query ", debug)
      : true;
  // short circuit if any is unsafe
  if (!q_is_safe) {
    return false;
  }

  bool k_is_safe = params.key.is_nested()
      ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
            params.key, "key ", debug)
      : true;
  if (!k_is_safe) {
    return false;
  }

  bool v_is_safe = params.value.is_nested()
      ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
            params.value, "value ", debug)
      : true;
  if (!v_is_safe) {
    return false;
  }

  // We now know none of the inputs have ragged num_heads, so we can safely
  // access .size(1)
  auto q_num_heads = params.query.size(1);
  auto k_num_heads = params.key.size(1);
  auto v_num_heads = params.value.size(1);
  bool same_num_heads =
      q_num_heads == k_num_heads && q_num_heads == v_num_heads;

  if (!same_num_heads) {
    if (input_requires_grad(params)) {
      if (debug) {
        TORCH_WARN(
            "Both fused kernels do not support training with broadcasted NT inputs.");
      }
      return false;
    }
    return try_broadcast_param_size(
        q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
  }

  return true;
}

inline bool check_nested_tensor(sdp_params const& params, bool debug) {
  // Return false if we have a nested tensor input
  if (!has_only_dense_inputs(params)) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels of the cpp version currently do not support Nested Tensor inputs.");
    }
    return false;
  }
  return true;
}

inline bool check_for_dropout(sdp_params const& params, bool debug) {
  if (params.dropout > 0.0) {
    if (debug) {
      TORCH_WARN("Both fused kernels do not support non-zero dropout.");
    }
    return false;
  }
  return true;
}

inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
  if (input_requires_grad(params)) {
    if (debug) {
      TORCH_WARN(
          "Memory efficient attention currently doesn't support training with NT inputs.");
    }
    return false;
  }
  return true;
}

inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
  if (params.attn_mask.has_value()) {
    if (debug) {
      TORCH_WARN("Flash Attention does not support non-null attn_mask.");
    }
    return false;
  }
  return true;
}

inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
  auto attn_mask = params.attn_mask;
  if (!attn_mask.has_value()) {
    return true;
  }
  if (attn_mask.value().requires_grad()) {
    return false;
  }
  auto batchSize = params.query.sym_size(0);
  auto qSize = params.query.sym_size(2);
  auto kvSize = params.key.sym_size(2);
  auto num_head = params.query.sym_size(1);
  if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) {
    return false;
  }
  if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) {
    return false;
  }
  if (attn_mask.value().dim() == 2) {
    return true;
  } else if (attn_mask.value().dim() == 4) {
    if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize)
        && (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) {
      return true;
    }
  }
  if (debug) {
    TORCH_WARN("Please use the following attn mask shapes: ",
        "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
        "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
  }
  return false;
}
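
// Example of shapes accepted above (illustrative): for a query of shape
// [B, H, L, E] and a key of shape [B, H, S, E], a 2-D mask of shape [L, S] or
// a 4-D mask of shape [B, H, L, S] passes, and each of those dims may instead
// be 1 (broadcast). A mask that requires grad is rejected on this path.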

inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
  auto query_dim = params.query.dim();
  if (!(query_dim == params.key.dim() && query_dim == params.value.dim() &&
        (query_dim == 4))) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels require query, key and value to be 4 dimensional, but got Query dim: ",
          query_dim,
          ", Key dim: ",
          params.key.dim(),
          ", Value dim: ",
          params.value.dim(),
          " instead.");
    }
    return false;
  }
  return true;
}

inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
  const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
  auto seq_len = nt_tensor_impl->opt_size(2);
  if (!seq_len.has_value()) {
    if (debug) {
      TORCH_WARN(
          "For both fused kernels, if one of key/value batch_size requires "
          "broadcasting and the other does not, then the other must have a ",
          "consistent seq_len dim.");
    }
    return false;
  }
  return true;
}

inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
  const auto q_num_heads = params.query.sym_size(-3);
  const auto k_num_heads = params.key.sym_size(-3);
  const auto v_num_heads = params.value.sym_size(-3);
  const bool same_kv_heads = k_num_heads == v_num_heads;

  if (!same_kv_heads) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
          "Key sizes: ",
          params.key.sizes(),
          ", Value sizes: ",
          params.value.sizes(),
          ", Query sizes: ",
          params.query.sizes(),
          " instead.");
    }
    return false;
  }
  // Check if grouped query attention is supported and validate the number of
  // heads
  if (q_num_heads % k_num_heads != 0) {
    if (debug) {
      TORCH_WARN(
          "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
          " Got input Key sizes(): ",
          params.key.sym_size(-3),
          ", Value sizes(): ",
          params.value.sym_size(-3),
          ", Query sizes(): ",
          params.query.sym_size(-3),
          " instead.");
    }
    return false;
  }
  return true;
}
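
// Example (illustrative): a query with 32 heads and key/value with 8 heads is
// accepted for GQA since 32 % 8 == 0; a key with 8 heads and a value with 4
// heads is rejected because key and value must have the same num_heads.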

template <bool supports_gqa>
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
  // This is expected to be called after check_tensor_shapes ensuring that the
  // size() calls won't error since the inputs are all 4 dimensional

  auto q_batch_size = params.query.sym_size(0);
  auto k_batch_size = params.key.sym_size(0);
  auto v_batch_size = params.value.sym_size(0);

  bool same_batch_size =
      q_batch_size == k_batch_size && q_batch_size == v_batch_size;

  auto q_num_heads = params.query.sym_size(-3);
  auto k_num_heads = params.key.sym_size(-3);
  auto v_num_heads = params.value.sym_size(-3);

  bool same_num_heads =
      q_num_heads == k_num_heads && q_num_heads == v_num_heads;

  if (!same_batch_size) {
    if (debug) {
      TORCH_WARN(
          "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
          "Got Query.sizes(): ",
          params.query.sizes(),
          ", Key.sizes(): ",
          params.key.sizes(),
          ", Value.sizes(): ",
          params.value.sizes(),
          " instead. To broadcast dense inputs, try using unsqueeze and expand before passing them into the kernel.");
    }
    return false;
  }

  if (params.enable_gqa && supports_gqa) {
    return check_grouped_query_attention(params, debug);
  }

  if (!same_num_heads) {
    if (debug) {
      TORCH_WARN(
          "For dense inputs, both fused kernels require query, key and value to have the same num_heads. ",
          "Got Query.sizes(): ",
          params.query.sizes(),
          ", Key.sizes(): ",
          params.key.sizes(),
          ", Value.sizes(): ",
          params.value.sizes(),
          " instead. To broadcast dense inputs, try using unsqueeze and expand before passing them into the kernel.");
    }
    return false;
  }
  // If all checks pass, return true
  return true;
}
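
// Sketch of the broadcast workaround suggested in the warnings above
// (illustrative; performed by the caller before invoking SDPA). Here a
// single-batch key/value is expanded to the query's batch size:
//
//   k = k.expand({q.size(0), k.size(1), k.size(2), k.size(3)});
//   v = v.expand({q.size(0), v.size(1), v.size(2), v.size(3)});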

inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
  // This is expected to be called after check_tensor_shapes ensuring that the
  // size() calls won't error since the inputs are all 4 dimensional
  auto q_batch_size = params.query.sym_size(0);
  auto k_batch_size = params.key.sym_size(0);
  auto v_batch_size = params.value.sym_size(0);

  bool same_batch_size =
      q_batch_size == k_batch_size && q_batch_size == v_batch_size;

  // num_heads logic for nested input is checked in
  // check_for_seq_len_0_nested_tensor as there is handling there to make sure
  // num_heads is not ragged
  bool broadcastable_batch_size = true;
  if (!same_batch_size) {
    if (input_requires_grad(params)) {
      if (debug) {
        TORCH_WARN(
            "Both fused kernels do not support training with broadcasted NT inputs.");
      }
      return false;
    }
    // Try to broadcast the batch size
    broadcastable_batch_size = try_broadcast_param_size(
        q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);

    // If only one of k or v requires broadcasting of batch size, the other
    // must have a consistent seq_len dim
    if (broadcastable_batch_size) {
      if (k_batch_size == 1 && v_batch_size != 1 &&
          !check_safe_kv_broadcast(params.value, debug)) {
        return false;
      }
      if (v_batch_size == 1 && k_batch_size != 1 &&
          !check_safe_kv_broadcast(params.key, debug)) {
        return false;
      }
    }
  }
  return broadcastable_batch_size;
}

inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
  // In some cases people will pass in 0-sized tensors; this will
  // cause the fused path to error with an unaligned mask
  bool zero_seq_len_q = params.query.sym_size(-2) == 0;
  bool zero_seq_len_k = params.key.sym_size(-2) == 0;
  if (zero_seq_len_q || zero_seq_len_k) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels do not support zero seq_len_q or seq_len_kv.");
    }
    return false;
  }
  return true;
}

template <bool ignore_singleton_dim>
inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
  // The stride checking for NestedTensors is done within the kernel
  // and .contiguous will be called if needed

  // This function checks that the last dimension of the inputs to
  // fused_attention have stride 1
  bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
      params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;

  // https://github.com/pytorch/pytorch/issues/116333
  // If the head_dim is size 1 the stride won't matter, but we
  // check this condition before padding the head_dim to 1
  if (ignore_singleton_dim) {
    qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
  }
  bool mask_stride_equal_1 = params.attn_mask.has_value()
      ? params.attn_mask.value().sym_stride(-1) == 1
      : true;
  if (!(qkv_strides_equal_1 && mask_stride_equal_1)) {
    if (debug) {
      std::ostringstream epilogue_message;
      if (params.attn_mask.has_value()) {
        epilogue_message << ", Attn_mask.stride(-1): "
                         << params.attn_mask.value().sym_stride(-1);
      }
      epilogue_message << " instead.";
      TORCH_WARN(
          "Both fused kernels require the last dimension of the input to have stride 1. ",
          "Got Query.stride(-1): ",
          params.query.sym_stride(-1),
          ", Key.stride(-1): ",
          params.key.sym_stride(-1),
          ", Value.stride(-1): ",
          params.value.sym_stride(-1),
          epilogue_message.str());
    }

    return false;
  }
  return true;
}
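
// Sketch of how a caller can satisfy the stride requirement above
// (illustrative): a tensor whose last dimension is not stride-1, e.g. one
// produced by a transpose, can simply be made contiguous first:
//
//   if (q.stride(-1) != 1) {
//     q = q.contiguous();  // last dim of a contiguous tensor has stride 1
//   }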

inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) {
  // We check the global context to see if the user has explicitly turned off
  // the flash sdp kernels
  if (!at::globalContext().userEnabledFlashSDP()) {
    if (debug) {
      TORCH_WARN("Flash attention has been runtime disabled.");
    }
    return false;
  }
  return true;
}

inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) {
  // We check the global context to see if the user has explicitly turned off
  // the mem_efficient sdp kernels
  if (!at::globalContext().userEnabledMemEfficientSDP()) {
    if (debug) {
      TORCH_WARN("Memory Efficient attention has been runtime disabled.");
    }
    return false;
  }
  return true;
}
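
// Note: the flags queried above are user-facing toggles. From Python they are
// typically flipped with torch.backends.cuda.enable_flash_sdp(False) and
// torch.backends.cuda.enable_mem_efficient_sdp(False), or scoped via the
// torch.nn.attention.sdpa_kernel context manager.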

} // namespace sdp