#include <ATen/native/transformers/sdp_utils_cpp.h>
namespace sdp {
namespace {
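// Returns the backend priority order used on CPU: flash attention is
// preferred, with the math implementation as the fallback.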
std::array<SDPBackend, num_backends> priority_order_cpp(sdp_params const& params) {
  constexpr std::array<SDPBackend, num_backends> default_order{
      SDPBackend::flash_attention,
      SDPBackend::math};

  return default_order;
}

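// The CPU flash kernel requires query, key, and value to share the same last
// (head) dimension; reject the fused path otherwise.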
bool check_head_dim_size_cpp(sdp_params const& params, bool debug) {
  const auto query_size_last = params.query.sym_size(-1);
  const auto key_size_last = params.key.sym_size(-1);
  const auto value_size_last = params.value.sym_size(-1);
  if (!(query_size_last == key_size_last &&
        query_size_last == value_size_last)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention requires q, k, and v to have the same last dimension.",
          " Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
          key_size_last,
          ", Value.size(-1): ",
          value_size_last,
          " instead.");
    }
    return false;
  }
  return true;
}

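// Runs every constraint gate against the inputs; the fused CPU flash kernel is
// usable only if all gates pass and the dtype check succeeds.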
bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
  constexpr auto cpp_supported_flash_dtypes =
      array_of<at::ScalarType>(at::kFloat, at::kDouble, at::kBFloat16, at::kHalf);

  // Define gate functions that determine if a flash kernel can be run
  constexpr auto constraints = array_of<bool (*)(sdp_params const&, bool)>(
      check_runtime_disabled_flash,
      check_nested_tensor,
      check_for_dropout,
      check_tensor_shapes,
      check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>,
      check_attn_mask_shape,
      check_head_dim_size_cpp,
      check_nonzero_sequence_lengths_dense,
      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>);
  for (auto& constraint : constraints) {
    if (!constraint(params, debug)) {
      return false;
    }
  }

  return check_tensor_dtype(params, cpp_supported_flash_dtypes, debug);
}
} // namespace

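// Walks the backend priority order and returns the first backend whose checks
// pass. If no backend can be used, the checks are re-run with debug warnings
// enabled and an error is raised.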
SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params) {
  // This function defines the priority order of the different sdp backends:
  // 1. Flash Attention
  // 2. Math fallback
  auto& ctx = at::globalContext();
  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP()) {
    return SDPBackend::error;
  }
  // Get the ideal kernel ordering
  const auto ordering = priority_order_cpp(kernel_params);

  // Because TORCH_CHECK checks if the condition is true, we negate debug so
  // that the statements are printed when debug is true.
  bool print_debug = false;
  for (auto& backend : ordering) {
    switch (backend) {
      case SDPBackend::flash_attention:
        if (use_flash_attention_cpp(kernel_params, print_debug)) {
          return SDPBackend::flash_attention;
        }
        break;
      case SDPBackend::math:
        if (ctx.userEnabledMathSDP()) {
          return SDPBackend::math;
        }
        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
  // 1. use_flash_attention_cpp did not satisfy the constraints to be run
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected.

  print_debug = true;
  TORCH_WARN("Flash attention kernel not used because:");
  use_flash_attention_cpp(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.");
  return SDPBackend::error;
}
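// Illustrative sketch (not part of this translation unit's logic): a caller on
// the CPU scaled_dot_product_attention path builds an sdp_params from the
// query/key/value tensors and the attention options, then branches on the
// selected backend. The variable name `params` below is an assumption used
// only for illustration.
//
//   switch (sdp::select_sdp_backend_cpp(params)) {
//     case sdp::SDPBackend::flash_attention:
//       // dispatch to the fused CPU flash attention kernel
//       break;
//     case sdp::SDPBackend::math:
//       // fall back to the unfused math implementation
//       break;
//     default:
//       TORCH_CHECK(false, "No viable SDPA backend for the given inputs.");
//   }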
} // namespace sdp