# mypy: allow-untyped-defs
import functools
import inspect
import logging
import math

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

from ..._dynamo.utils import counters
from ..pattern_matcher import (
    filter_nodes,
    fwd_only,
    gen_register_replacement,
    joint_fwd_bwd,
)


log = logging.getLogger(__name__)
aten = torch.ops.aten


if torch.version.hip:

    def _scaled_dot_product_attention(*args, **kwargs):
        with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
            return aten.scaled_dot_product_attention(*args, **kwargs)

else:
    _scaled_dot_product_attention = aten.scaled_dot_product_attention


def _sfdp_pattern_1(query, key, value, inv_scale):
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .matmul(value)
    )


def _sfdp_replacement_1(query, key, value, inv_scale):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_2(query, key, value, scale_factor):
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .mul(scale_factor)
        .softmax(dim=-1)
        .matmul(value)
    )


def _sfdp_replacement_2(query, key, value, scale_factor):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=scale_factor,
    )


def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale_factor)
        .softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)


def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )


def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)


def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale_factor,
    )


def _sfdp_pattern_5(query, key, value, attn_mask):
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    # attn_weight = torch.dropout(attn_weight, dropout_p)
    return attn_weight @ value


def _sfdp_replacement_5(query, key, value, attn_mask):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
    )
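

# Illustrative sketch for exposition only; the helper name `_demo_pattern_1_equivalence`
# is hypothetical and nothing in the pass calls it. It shows, for small fp32 inputs, why
# the rewrites above are valid: the unfused matmul/div/softmax/matmul chain of pattern 1
# and an SDPA call with scale=1.0 / inv_scale compute the same attention output (assumes
# the `scale` kwarg of the functional SDPA API, present in recent PyTorch releases).
def _demo_pattern_1_equivalence(inv_scale=8.0, atol=1e-5):
    # (batch, num_heads, seq_len, head_dim) inputs, matching what SDPA expects
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    unfused = _sfdp_pattern_1(q, k, v, inv_scale)
    fused = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
    return torch.allclose(unfused, fused, atol=atol)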


def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value


def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_7(query, key, value, dropout_p):
    # In real workloads the inputs to matmul are permuted, which causes matmul to
    # expand into a series of expand and clone calls; we want the same to happen
    # during pattern tracing.
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_7(query, key, value, dropout_p):
    # sdpa prefers its inputs in permuted format and makes a copy to put them in
    # that format if they aren't already; to keep the replacement efficient,
    # ensure that the inputs passed to sdpa are already in the required order.
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_8(query, key, value):
    # no dropout version of pattern 7
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_8(query, key, value):
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
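

# Patterns 9 and 10 below scale the query by 1/sqrt(head_dim) *before* the matmul
# instead of dividing the scores afterwards. Since (q / sqrt(d)) @ k^T equals
# (q @ k^T) / sqrt(d), both formulations map onto SDPA's default scaling, so the
# corresponding replacements can omit an explicit `scale` argument.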


def _sfdp_pattern_9(query, key, value, dropout_p):
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_9(query, key, value, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_10(query, key, value):
    # no dropout version of 9
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_10(query, key, value):
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )


def _sfdp_pattern_11(query, key, value, inv_scale):
    # Mainly for huggingface models
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)


def _sfdp_replacement_11(query, key, value, inv_scale):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.nn.functional.dropout(
        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(v)


def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale_factor,
    )


def _sfdp_pattern_13(query, key, value, dropout_p):
    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
    attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
    return torch.bmm(attn_weight, value)


def _sfdp_replacement_13(query, key, value, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        dropout_p=dropout_p,
        scale=1.0,
    ).squeeze(0)


def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
    # for BertLarge
    # Permutations are needed to create clones in graph.
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
        .softmax(dim=-1)
        .matmul(v)
    )


def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
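

# Illustrative sketch for exposition only; the helper name `_demo_pattern_14_mask`
# is hypothetical and not used by the pass. It shows how the additive-mask patterns
# (5, 6, 14, 16) line up with SDPA: a floating-point attn_mask is simply added to the
# scaled scores before the softmax, so the pattern's `+ attn_mask` becomes SDPA's
# `attn_mask=` argument.
def _demo_pattern_14_mask(inv_scale=4.0, atol=1e-5):
    # inputs in (batch, seq_len, num_heads, head_dim) layout, as the pattern expects
    q, k, v = (torch.randn(2, 8, 4, 16) for _ in range(3))
    attn_mask = torch.randn(2, 1, 1, 8)
    unfused = _sfdp_pattern_14(q, k, v, attn_mask, inv_scale)
    fused = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2),
        k.transpose(1, 2),
        v.transpose(1, 2),
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
    return torch.allclose(unfused, fused, atol=atol)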


def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
    # for DistilBert
    # Permutations are needed to create clones in graph.
    # Ref: https://github.com/pytorch/pytorch/issues/119911
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v


def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
    counters["inductor"]["fuse_attention"] += 1
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # do attn_mask->logical_not() in _scaled_dot_product_attention
    attn_mask = (
        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
    # for BertLarge with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        torch.nn.functional.dropout(
            (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
                dim=-1
            ),
            dropout_p,
        )
        .to(dtype=query.dtype)
        .matmul(v)
    )


def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=query.dtype),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
    # for DistilBert with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return (
        torch.nn.functional.dropout(
            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
        )
        @ v
    )


def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # do attn_mask->logical_not() in _scaled_dot_product_attention
    attn_mask = (
        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
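

# Clarifying note on patterns 15 and 17 above: the traced graph masks out positions
# where `attn_mask == 0` by filling the scores with -inf, whereas the replacements
# hand SDPA a boolean mask built from `attn_mask == 1`, in which True means "attend".
# The two are equivalent because SDPA's boolean-mask convention performs the
# masked_fill of the excluded positions internally, as the inline comments note.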


def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p):
    # for hf_GPT2 with dropout (introduces clone node) for inference
    # it also returns permuted key & value
    query = query.permute([0, 2, 1, 3])
    key = key.permute([0, 2, 1, 3])
    value = value.permute([0, 2, 1, 3])
    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
    inv_scale = torch.full(
        [],
        value.size(-1) ** 0.5,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = attn_weights.div(inv_scale)
    causal_mask_value = torch.full(
        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
    )
    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
    return (
        (
            torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul(
                value
            )
        ),
        key,
        value,
    )


def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    permuted_key = key.transpose(1, 2)
    permuted_value = value.transpose(1, 2)
    return (
        _scaled_dot_product_attention(
            query.transpose(1, 2),
            permuted_key,
            permuted_value,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=False,
            scale=1.0 / math.sqrt(value.size(-1)),
        ),
        permuted_key,
        permuted_value,
    )


def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p):
    # for token-classification+gpt2 / text-generation+gpt2
    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
    inv_scale = torch.full(
        [],
        value.size(-1) ** 0.5,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = attn_weights.div(inv_scale)
    causal_mask_value = torch.full(
        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
    )
    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
    attn_weights = attn_weights + attn_mask
    attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
    return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value)


def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = torch.where(causal_mask, attn_mask, fill_value)
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / math.sqrt(value.size(-1)),
    )
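

# Illustrative sketch for exposition only; the helper name `_demo_pattern_19_mask_merge`
# is hypothetical and not used by the pass. It shows why replacement 19 may fold the
# boolean causal mask and the additive attention mask into a single float mask: filling
# the masked-out positions with -inf before the softmax gives (numerically) the same
# result as the pattern's finfo-min masking followed by the mask addition.
def _demo_pattern_19_mask_merge(atol=1e-5):
    # (batch, num_heads, seq_len, head_dim) inputs, already in SDPA layout
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    causal_mask = torch.tril(torch.ones(1, 1, 8, 8, dtype=torch.bool))
    attn_mask = torch.randn(2, 1, 1, 8)
    fill_value = torch.full((), -float("inf"))
    merged_mask = torch.where(causal_mask, attn_mask, fill_value)
    unfused = _sfdp_pattern_19(q, k, v, causal_mask, attn_mask, dropout_p=0.0)
    fused = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=merged_mask, scale=1.0 / math.sqrt(v.size(-1))
    )
    return torch.allclose(unfused, fused, atol=atol)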


def _sfdp_params_check(match):
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
    key = match.kwargs["key"].meta["val"]
    value = match.kwargs["value"].meta["val"]
    if not (query.dtype == key.dtype == value.dtype) or not (
        query.device == key.device == value.device
    ):
        return False
    add_mask_node = filter_nodes(match.nodes, aten.add.Tensor)
    # Has attn_mask add.
    if len(add_mask_node) > 0:
        attn_mask_node = add_mask_node[0].args[1]
        # attn_mask_node may be a float/int number.
        if not hasattr(attn_mask_node, "meta"):
            return False
        attn_mask = attn_mask_node.meta["val"]  # type: ignore[union-attr]
        # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool;
        # attn_mask.dtype == torch.float is also allowed for models like albert.
        if (
            not isinstance(attn_mask, torch.Tensor)
            or not (
                attn_mask.dtype == query.dtype
                or attn_mask.dtype == torch.bool
                or attn_mask.dtype == torch.float
            )
            or query.device != attn_mask.device
        ):
            return False
    return True


def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False):
    def fn(match):
        if (
            disable_cuda
            and "query" in match.kwargs
            and "cuda" in str(match.kwargs["query"].meta["val"].device)
        ):
            return False
        if scale_factor_op is not None:
            scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
            # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
            scale_factor = scale_factor_node.args[1]
            # make sure the scale_factor is a float/int. SymInt?
            if not isinstance(scale_factor, (float, int)):
                return False
        return _sfdp_params_check(match)

    return fn


def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial but also updates the signature on the returned function
    """
    original_sig = inspect.signature(func)
    parameters = original_sig.parameters

    new_parameters = {
        key: value for key, value in parameters.items() if key not in kwargs
    }
    new_sig = inspect.Signature(parameters=list(new_parameters.values()))

    partial_func = functools.partial(func, **kwargs)

    def wrapper(*args, **kwargs):
        return partial_func(*args, **kwargs)

    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__

    return wrapper
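

# Illustrative sketch for exposition only; the helper name `_demo_partialize_signature`
# is hypothetical and not used by the pass. It shows the property the pattern matcher
# relies on downstream: binding dropout_p via partialize_and_update_signature removes
# the parameter from the reported signature of the returned wrapper.
def _demo_partialize_signature():
    inference_pattern_3 = partialize_and_update_signature(
        _sfdp_pattern_3, dropout_p=0.0
    )
    return "dropout_p" not in inspect.signature(inference_pattern_3).parameters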


def _get_sfdp_patterns():
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    g_inp = functools.partial(
        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
    )
    # attn_mask
    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
    # inv_scale
    c_inp = functools.partial(torch.tensor, 2.0, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    d = {"dropout_p": 0.113377}

    # we could also generate all these patterns in 3d.. TODO
    g_3d_inp = functools.partial(
        torch.empty, (1024, 128, 128), device=device, requires_grad=True
    )

    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
    g_bs1_inp = functools.partial(
        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
    )
    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)

    # softmax will generate a dtype conversion on inputs if they are in half,
    # but will not in float, so we generate a pattern for both
    for dtype in [torch.float, torch.half]:
        g = functools.partial(g_inp, dtype=dtype)
        b = functools.partial(b_inp, dtype=dtype)
        b_float = functools.partial(b_inp, dtype=torch.float)
        b_bool = functools.partial(b_inp, dtype=torch.bool)
        m = functools.partial(m_inp, dtype=dtype)
        m_float = functools.partial(m_inp, dtype=torch.float)
        m_bool = functools.partial(m_inp, dtype=torch.bool)
        c = functools.partial(c_inp, dtype=dtype)
        g_3d = functools.partial(g_3d_inp, dtype=dtype)
        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)

        candidates = [
            (
                _sfdp_pattern_1,
                _sfdp_replacement_1,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_2,
                _sfdp_replacement_2,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_3,
                _sfdp_replacement_3,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_4,
                _sfdp_replacement_4,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_5,
                _sfdp_replacement_5,
                [g(), g(), g(), b()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_6,
                _sfdp_replacement_6,
                [g(), g(), g(), b()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_7,
                _sfdp_replacement_7,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_8,
                _sfdp_replacement_8,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_9,
                _sfdp_replacement_9,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_10,
                _sfdp_replacement_10,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_11,
                _sfdp_replacement_11,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_12,
                _sfdp_replacement_12,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_13,
                _sfdp_replacement_13,
                [g_3d(), g_3d(), g_3d()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_14,
                _sfdp_replacement_14,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_15,
                _sfdp_replacement_15,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_17,
                _sfdp_replacement_17,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g(), g(), g(), m_bool()],
                d,
                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
                _sfdp_extra_check(disable_cuda=True),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()],
                d,
                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
                _sfdp_extra_check(disable_cuda=True),
            ),
            (
                _sfdp_pattern_19,
                _sfdp_replacement_19,
                [g(), g(), g(), b_bool(), b_float()],
                d,
                _sfdp_params_check,
            ),
        ]
        mask_fp32_patterns = ["pattern_16"]
        if dtype == torch.half:
            # Add inputs of half-precision q/k/v and an fp32 mask, for models like albert.
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g(), g(), g(), m_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )

        for pattern, replacement, args, workaround, extra_check in candidates:
            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
            # gets serialized to a python file and does not require tracing at runtime.
            assert isinstance(workaround, dict)
            name = pattern.__name__

            if dtype != torch.float:
                name += "_half"
                if (
                    any(p in name for p in mask_fp32_patterns)
                    and args[3].dtype == torch.float32
                ):
                    name += "_mask_fp32"
            if args[0].size(0) == 1:
                name += "_bs1"

            training_name = name + "_training"
            yield training_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": joint_fwd_bwd,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
            }

            if workaround:
                assert len(workaround) == 1 and "dropout_p" in workaround
                # functools.partial insufficient because we look at signature downstream
                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
                replacement = partialize_and_update_signature(
                    replacement, dropout_p=0.0
                )
                workaround = {}

            inference_name = name + "_inference"
            yield inference_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": fwd_only,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
                # with dropout turned into clone, we end up with a number of
                # semantically identical graphs
                "skip_duplicates": True,
            }


@functools.lru_cache(None)
def _sfdp_init():
    for key, register_replacement_kwargs in _get_sfdp_patterns():
        gen_register_replacement(key, **register_replacement_kwargs)
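

# Illustrative sketch for exposition only; the helper name `_demo_fuse_attention_counter`
# is hypothetical and nothing in the pass calls it. One way to observe the fusion from
# user code is to compile an eager attention implementation and then inspect the counter
# that the replacement functions above increment; whether a given pattern actually fires
# depends on dtypes, shapes, and the backend configuration.
def _demo_fuse_attention_counter():
    inv_scale = math.sqrt(16)

    def eager_attention(query, key, value):
        # same shape as _sfdp_pattern_1, with the scale baked in as a constant
        return (
            torch.matmul(query, key.transpose(-2, -1))
            .div(inv_scale)
            .softmax(dim=-1)
            .matmul(value)
        )

    counters["inductor"]["fuse_attention"] = 0
    compiled = torch.compile(eager_attention)
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    compiled(q, k, v)
    return counters["inductor"]["fuse_attention"]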