xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/fuse_attention.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import inspect
4import logging
5import math
6
7import torch
8from torch.nn.attention import sdpa_kernel, SDPBackend
9
10from ..._dynamo.utils import counters
11from ..pattern_matcher import (
12    filter_nodes,
13    fwd_only,
14    gen_register_replacement,
15    joint_fwd_bwd,
16)
17
18
19log = logging.getLogger(__name__)
20aten = torch.ops.aten
21
22
23if torch.version.hip:
24
25    def _scaled_dot_product_attention(*args, **kwargs):
26        with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
27            return aten.scaled_dot_product_attention(*args, **kwargs)
28
29else:
30    _scaled_dot_product_attention = aten.scaled_dot_product_attention
31
32
33def _sfdp_pattern_1(query, key, value, inv_scale):
34    return (
35        torch.matmul(query, key.transpose(-2, -1))
36        .div(inv_scale)
37        .softmax(dim=-1)
38        .matmul(value)
39    )
40
41
42def _sfdp_replacement_1(query, key, value, inv_scale):
43    counters["inductor"]["fuse_attention"] += 1
44    return _scaled_dot_product_attention(
45        query.contiguous(),
46        key.contiguous(),
47        value.contiguous(),
48        attn_mask=None,
49        dropout_p=0.0,
50        is_causal=False,
51        scale=1.0 / inv_scale,
52    )
53
54
55def _sfdp_pattern_2(query, key, value, scale_factor):
56    return (
57        torch.matmul(query, key.transpose(-2, -1))
58        .mul(scale_factor)
59        .softmax(dim=-1)
60        .matmul(value)
61    )
62
63
64def _sfdp_replacement_2(query, key, value, scale_factor):
65    counters["inductor"]["fuse_attention"] += 1
66    return _scaled_dot_product_attention(
67        query.contiguous(),
68        key.contiguous(),
69        value.contiguous(),
70        attn_mask=None,
71        dropout_p=0.0,
72        is_causal=False,
73        scale=scale_factor,
74    )
75
76
77def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
78    return torch.nn.functional.dropout(
79        torch.matmul(query, key.transpose(-2, -1))
80        .div(inv_scale_factor)
81        .softmax(dim=-1),
82        p=dropout_p,
83    ).matmul(value)
84
85
86def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
87    counters["inductor"]["fuse_attention"] += 1
88    return _scaled_dot_product_attention(
89        query.contiguous(),
90        key.contiguous(),
91        value.contiguous(),
92        attn_mask=None,
93        dropout_p=dropout_p,
94        is_causal=False,
95        scale=1.0 / inv_scale_factor,
96    )
97
98
99def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
100    return torch.nn.functional.dropout(
101        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
102        p=dropout_p,
103    ).matmul(value)
104
105
106def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
107    counters["inductor"]["fuse_attention"] += 1
108    return _scaled_dot_product_attention(
109        query.contiguous(),
110        key.contiguous(),
111        value.contiguous(),
112        attn_mask=None,
113        dropout_p=dropout_p,
114        is_causal=False,
115        scale=scale_factor,
116    )
117
118
119def _sfdp_pattern_5(query, key, value, attn_mask):
120    attn_weight = torch.softmax(
121        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
122    )
123    # attn_weight = torch.dropout(attn_weight, dropout_p)
124    return attn_weight @ value
125
126
127def _sfdp_replacement_5(query, key, value, attn_mask):
128    counters["inductor"]["fuse_attention"] += 1
129    return _scaled_dot_product_attention(
130        query.contiguous(),
131        key.contiguous(),
132        value.contiguous(),
133        attn_mask=attn_mask.to(dtype=query.dtype),
134        dropout_p=0.0,
135        is_causal=False,
136    )
137
138
139def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
140    attn_weight = torch.softmax(
141        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
142    )
143    attn_weight = torch.dropout(attn_weight, dropout_p, True)
144    return attn_weight @ value
145
146
147def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
148    counters["inductor"]["fuse_attention"] += 1
149    return _scaled_dot_product_attention(
150        query.contiguous(),
151        key.contiguous(),
152        value.contiguous(),
153        attn_mask=attn_mask.to(dtype=query.dtype),
154        dropout_p=dropout_p,
155        is_causal=False,
156    )
157
158
159def _sfdp_pattern_7(query, key, value, dropout_p):
160    # in real workloads inputs to matmul are permuted
161    # causing matmul to expand to a series of expand and clone calls
162    # we want the same to happen during pattern tracing
163    q = query.permute(0, 2, 1, 3)
164    k = key.permute(0, 2, 1, 3)
165    v = value.permute(0, 2, 1, 3)
166    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
167    div = div.to(torch.float32)
168    attn_weight = torch.softmax(div, dim=-1)
169    attn_weight = torch.dropout(attn_weight, dropout_p, True)
170    attn_weight = attn_weight.to(torch.float16)
171    return attn_weight @ v
172
173
174def _sfdp_replacement_7(query, key, value, dropout_p):
175    # sdpa prefers inputs in permuted format
176    # it makes a copy to put them in this format
177    # if they aren't already
178    # to make replacement efficient ensure that inputs to sdpa
179    # are in required order
180    counters["inductor"]["fuse_attention"] += 1
181    q = query.permute(0, 2, 1, 3)
182    k = key.permute(0, 2, 1, 3)
183    v = value.permute(0, 2, 1, 3)
184    return _scaled_dot_product_attention(
185        q,
186        k,
187        v,
188        attn_mask=None,  # attn_mask,
189        dropout_p=dropout_p,
190        is_causal=False,
191    )
192
193
194def _sfdp_pattern_8(query, key, value):
195    # no dropout version of pattern 7
196    q = query.permute(0, 2, 1, 3)
197    k = key.permute(0, 2, 1, 3)
198    v = value.permute(0, 2, 1, 3)
199    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
200    div = div.to(torch.float32)
201    attn_weight = torch.softmax(div, dim=-1)
202    attn_weight = attn_weight.to(torch.float16)
203    return attn_weight @ v
204
205
206def _sfdp_replacement_8(query, key, value):
207    counters["inductor"]["fuse_attention"] += 1
208    q = query.permute(0, 2, 1, 3)
209    k = key.permute(0, 2, 1, 3)
210    v = value.permute(0, 2, 1, 3)
211    return _scaled_dot_product_attention(
212        q,
213        k,
214        v,
215        attn_mask=None,  # attn_mask,
216        dropout_p=0.0,
217        is_causal=False,
218    )
219
220
221def _sfdp_pattern_9(query, key, value, dropout_p):
222    q = query.permute(0, 2, 1, 3)
223    k = key.permute(0, 2, 1, 3)
224    v = value.permute(0, 2, 1, 3)
225    q = q / math.sqrt(q.size(-1))
226    div = q @ k.transpose(-2, -1)
227    div = div.to(torch.float32)
228    attn_weight = torch.softmax(div, dim=-1)
229    attn_weight = torch.dropout(attn_weight, dropout_p, True)
230    attn_weight = attn_weight.to(torch.float16)
231    return attn_weight @ v
232
233
234def _sfdp_replacement_9(query, key, value, dropout_p):
235    counters["inductor"]["fuse_attention"] += 1
236    q = query.permute(0, 2, 1, 3)
237    k = key.permute(0, 2, 1, 3)
238    v = value.permute(0, 2, 1, 3)
239    return _scaled_dot_product_attention(
240        q,
241        k,
242        v,
243        attn_mask=None,  # attn_mask,
244        dropout_p=dropout_p,
245        is_causal=False,
246    )
247
248
249def _sfdp_pattern_10(query, key, value):
250    # no dropout version of 9
251    q = query.permute(0, 2, 1, 3)
252    k = key.permute(0, 2, 1, 3)
253    v = value.permute(0, 2, 1, 3)
254    q = q / math.sqrt(q.size(-1))
255    div = q @ k.transpose(-2, -1)
256    div = div.to(torch.float32)
257    attn_weight = torch.softmax(div, dim=-1)
258    attn_weight = attn_weight.to(torch.float16)
259    return attn_weight @ v
260
261
262def _sfdp_replacement_10(query, key, value):
263    counters["inductor"]["fuse_attention"] += 1
264    q = query.permute(0, 2, 1, 3)
265    k = key.permute(0, 2, 1, 3)
266    v = value.permute(0, 2, 1, 3)
267    return _scaled_dot_product_attention(
268        q,
269        k,
270        v,
271        attn_mask=None,  # attn_mask,
272        dropout_p=0.0,
273        is_causal=False,
274    )
275
276
277def _sfdp_pattern_11(query, key, value, inv_scale):
278    # Mainly for huggingface models
279    q = query.permute(0, 2, 1, 3)
280    k = key.permute(0, 2, 1, 3)
281    v = value.permute(0, 2, 1, 3)
282    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)
283
284
285def _sfdp_replacement_11(query, key, value, inv_scale):
286    counters["inductor"]["fuse_attention"] += 1
287    return _scaled_dot_product_attention(
288        query.transpose(1, 2),
289        key.transpose(1, 2),
290        value.transpose(1, 2),
291        attn_mask=None,
292        dropout_p=0.0,
293        is_causal=False,
294        scale=1.0 / inv_scale,
295    )
296
297
298def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
299    q = query.permute(0, 2, 1, 3)
300    k = key.permute(0, 2, 1, 3)
301    v = value.permute(0, 2, 1, 3)
302    return torch.nn.functional.dropout(
303        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
304        p=dropout_p,
305    ).matmul(v)
306
307
308def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
309    counters["inductor"]["fuse_attention"] += 1
310    return _scaled_dot_product_attention(
311        query.transpose(1, 2),
312        key.transpose(1, 2),
313        value.transpose(1, 2),
314        attn_mask=None,
315        dropout_p=dropout_p,
316        is_causal=False,
317        scale=1.0 / inv_scale_factor,
318    )
319
320
321def _sfdp_pattern_13(query, key, value, dropout_p):
322    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
323    attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
324    return torch.bmm(attn_weight, value)
325
326
327def _sfdp_replacement_13(query, key, value, dropout_p):
328    counters["inductor"]["fuse_attention"] += 1
329    return _scaled_dot_product_attention(
330        query.unsqueeze(0),
331        key.unsqueeze(0),
332        value.unsqueeze(0),
333        dropout_p=dropout_p,
334        scale=1.0,
335    ).squeeze(0)
336
337
338def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
339    # for BertLarge
340    # Permutations are needed to create clones in graph.
341    q = query.permute([0, 2, 1, 3])
342    k = key.permute([0, 2, 1, 3])
343    v = value.permute([0, 2, 1, 3])
344    return (
345        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
346        .softmax(dim=-1)
347        .matmul(v)
348    )
349
350
351def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
352    counters["inductor"]["fuse_attention"] += 1
353    return _scaled_dot_product_attention(
354        query.transpose(1, 2),
355        key.transpose(1, 2),
356        value.transpose(1, 2),
357        attn_mask=attn_mask.to(dtype=query.dtype),
358        dropout_p=0.0,
359        is_causal=False,
360        scale=1.0 / inv_scale,
361    )
362
363
364def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
365    # for DistilBert
366    # Permutations are needed to create clones in graph.
367    # Ref: https://github.com/pytorch/pytorch/issues/119911
368    q = query.permute([0, 2, 1, 3])
369    k = key.permute([0, 2, 1, 3])
370    v = value.permute([0, 2, 1, 3])
371    bs = q.size(0)
372    k_len = k.size(-2)
373    scores = q @ k.transpose(-2, -1)
374    scores = scores.div(inv_scale)
375    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
376    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
377    return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v
378
379
380def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
381    counters["inductor"]["fuse_attention"] += 1
382    bs = query.size(0)
383    n_head = query.size(2)
384    q_len = query.size(1)
385    k_len = key.size(1)
386    # do attn_mask->logical_not() in _scaled_dot_product_attention
387    attn_mask = (
388        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
389    )
390    return _scaled_dot_product_attention(
391        query.transpose(1, 2),
392        key.transpose(1, 2),
393        value.transpose(1, 2),
394        attn_mask=attn_mask.to(dtype=torch.bool),
395        dropout_p=0.0,
396        is_causal=False,
397        scale=1.0 / inv_scale,
398    )
399
400
401def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
402    # for BertLarge with dropout
403    q = query.permute([0, 2, 1, 3])
404    k = key.permute([0, 2, 1, 3])
405    v = value.permute([0, 2, 1, 3])
406    return (
407        torch.nn.functional.dropout(
408            (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
409                dim=-1
410            ),
411            dropout_p,
412        )
413        .to(dtype=query.dtype)
414        .matmul(v)
415    )
416
417
418def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
419    counters["inductor"]["fuse_attention"] += 1
420    return _scaled_dot_product_attention(
421        query.transpose(1, 2),
422        key.transpose(1, 2),
423        value.transpose(1, 2),
424        attn_mask=attn_mask.to(dtype=query.dtype),
425        dropout_p=dropout_p,
426        is_causal=False,
427        scale=1.0 / inv_scale,
428    )
429
430
431def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
432    # for DistilBert with dropout
433    q = query.permute([0, 2, 1, 3])
434    k = key.permute([0, 2, 1, 3])
435    v = value.permute([0, 2, 1, 3])
436    bs = q.size(0)
437    k_len = k.size(-2)
438    scores = q @ k.transpose(-2, -1)
439    scores = scores.div(inv_scale)
440    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
441    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
442    return (
443        torch.nn.functional.dropout(
444            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
445        )
446        @ v
447    )
448
449
450def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
451    counters["inductor"]["fuse_attention"] += 1
452    bs = query.size(0)
453    n_head = query.size(2)
454    q_len = query.size(1)
455    k_len = key.size(1)
456    # do attn_mask->logical_not() in _scaled_dot_product_attention
457    attn_mask = (
458        (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
459    )
460    return _scaled_dot_product_attention(
461        query.transpose(1, 2),
462        key.transpose(1, 2),
463        value.transpose(1, 2),
464        attn_mask=attn_mask.to(dtype=torch.bool),
465        dropout_p=dropout_p,
466        is_causal=False,
467        scale=1.0 / inv_scale,
468    )
469
470
471def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p):
472    # for hf_GPT2 with dropout (introduces clone node) for inference
473    # it also returns permuted key & value
474    query = query.permute([0, 2, 1, 3])
475    key = key.permute([0, 2, 1, 3])
476    value = value.permute([0, 2, 1, 3])
477    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
478    inv_scale = torch.full(
479        [],
480        value.size(-1) ** 0.5,
481        dtype=attn_weights.dtype,
482        device=attn_weights.device,
483    )
484    attn_weights = attn_weights.div(inv_scale)
485    causal_mask_value = torch.full(
486        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
487    )
488    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
489    return (
490        (
491            torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul(
492                value
493            )
494        ),
495        key,
496        value,
497    )
498
499
500def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p):
501    counters["inductor"]["fuse_attention"] += 1
502    permuted_key = key.transpose(1, 2)
503    permuted_value = value.transpose(1, 2)
504    return (
505        _scaled_dot_product_attention(
506            query.transpose(1, 2),
507            permuted_key,
508            permuted_value,
509            attn_mask=causal_mask,
510            dropout_p=dropout_p,
511            is_causal=False,
512            scale=1.0 / math.sqrt(value.size(-1)),
513        ),
514        permuted_key,
515        permuted_value,
516    )
517
518
519def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p):
520    # for token-classification+gpt2 / text-generation+gpt2
521    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
522    inv_scale = torch.full(
523        [],
524        value.size(-1) ** 0.5,
525        dtype=attn_weights.dtype,
526        device=attn_weights.device,
527    )
528    attn_weights = attn_weights.div(inv_scale)
529    causal_mask_value = torch.full(
530        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
531    )
532    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
533    attn_weights = attn_weights + attn_mask
534    attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
535    return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value)
536
537
538def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
539    counters["inductor"]["fuse_attention"] += 1
540    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
541    attn_mask = torch.where(causal_mask, attn_mask, fill_value)
542    return _scaled_dot_product_attention(
543        query,
544        key,
545        value,
546        attn_mask=attn_mask,
547        dropout_p=dropout_p,
548        is_causal=False,
549        scale=1.0 / math.sqrt(value.size(-1)),
550    )
551
552
553def _sfdp_params_check(match):
554    assert all(k in match.kwargs for k in ("query", "key", "value"))
555    query = match.kwargs["query"].meta["val"]
556    key = match.kwargs["key"].meta["val"]
557    value = match.kwargs["value"].meta["val"]
558    if not (query.dtype == key.dtype == value.dtype) or not (
559        query.device == key.device == value.device
560    ):
561        return False
562    add_mask_node = filter_nodes(match.nodes, aten.add.Tensor)
563    # Has attn_mask add.
564    if len(add_mask_node) > 0:
565        attn_mask_node = add_mask_node[0].args[1]
566        # attn_mask_node may be a float/int number.
567        if not hasattr(attn_mask_node, "meta"):
568            return False
569        attn_mask = attn_mask_node.meta["val"]  # type: ignore[union-attr]
570        # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool
571        # attn_mask.dtype == torch.float for models like albert.
572        if (
573            not isinstance(attn_mask, torch.Tensor)
574            or not (
575                attn_mask.dtype == query.dtype
576                or attn_mask.dtype == torch.bool
577                or attn_mask.dtype == torch.float
578            )
579            or query.device != attn_mask.device
580        ):
581            return False
582    return True
583
584
585def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False):
586    def fn(match):
587        if (
588            disable_cuda
589            and "query" in match.kwargs
590            and "cuda" in str(match.kwargs["query"].meta["val"].device)
591        ):
592            return False
593        if scale_factor_op is not None:
594            scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0]
595            # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
596            scale_factor = scale_factor_node.args[1]
597            # make sure the scale_factor a float/int. SymInt?
598            if not isinstance(scale_factor, (float, int)):
599                return False
600        return _sfdp_params_check(match)
601
602    return fn
603
604
605def partialize_and_update_signature(func, **kwargs):
606    """
607    Equivalent to functools.partial but also updates the signature on returned function
608    """
609    original_sig = inspect.signature(func)
610    parameters = original_sig.parameters
611
612    new_parameters = {
613        key: value for key, value in parameters.items() if key not in kwargs
614    }
615    new_sig = inspect.Signature(parameters=list(new_parameters.values()))
616
617    partial_func = functools.partial(func, **kwargs)
618
619    def wrapper(*args, **kwargs):
620        return partial_func(*args, **kwargs)
621
622    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
623    wrapper.__name__ = func.__name__
624
625    return wrapper
626
627
628def _get_sfdp_patterns():
629    from .joint_graph import patterns
630
631    if torch.cuda.is_available():
632        # workaround https://github.com/pytorch/pytorch/issues/97894
633        device = "cuda"
634    else:
635        device = "cpu"
636
637    # sizes/values don't actually matter for initial trace
638    # once we get a possible match we re-trace with the actual values and verify the match still holds
639    g_inp = functools.partial(
640        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
641    )
642    # attn_mask
643    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
644    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
645    # inv_scale
646    c_inp = functools.partial(torch.tensor, 2.0, device=device)
647    # workaround https://github.com/pytorch/pytorch/issues/97894
648    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
649    d = {"dropout_p": 0.113377}
650
651    # we could also generate all these patterns in 3d.. TODO
652    g_3d_inp = functools.partial(
653        torch.empty, (1024, 128, 128), device=device, requires_grad=True
654    )
655
656    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
657    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
658    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
659    g_bs1_inp = functools.partial(
660        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
661    )
662    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)
663
664    # softmax will generate a dtype conversion on inputs if they are in half,
665    # but will not in float, so we generate a pattern for both
666    for dtype in [torch.float, torch.half]:
667        g = functools.partial(g_inp, dtype=dtype)
668        b = functools.partial(b_inp, dtype=dtype)
669        b_float = functools.partial(b_inp, dtype=torch.float)
670        b_bool = functools.partial(b_inp, dtype=torch.bool)
671        m = functools.partial(m_inp, dtype=dtype)
672        m_float = functools.partial(m_inp, dtype=torch.float)
673        m_bool = functools.partial(m_inp, dtype=torch.bool)
674        c = functools.partial(c_inp, dtype=dtype)
675        g_3d = functools.partial(g_3d_inp, dtype=dtype)
676        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
677        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
678        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
679        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)
680
681        candidates = [
682            (
683                _sfdp_pattern_1,
684                _sfdp_replacement_1,
685                [g(), g(), g(), c()],
686                {},
687                _sfdp_extra_check(aten.div.Tensor),
688            ),
689            (
690                _sfdp_pattern_2,
691                _sfdp_replacement_2,
692                [g(), g(), g(), c()],
693                {},
694                _sfdp_extra_check(aten.mul.Tensor),
695            ),
696            (
697                _sfdp_pattern_3,
698                _sfdp_replacement_3,
699                [g(), g(), g(), c()],
700                d,
701                _sfdp_extra_check(aten.div.Tensor),
702            ),
703            (
704                _sfdp_pattern_4,
705                _sfdp_replacement_4,
706                [g(), g(), g(), c()],
707                d,
708                _sfdp_extra_check(aten.mul.Tensor),
709            ),
710            (
711                _sfdp_pattern_5,
712                _sfdp_replacement_5,
713                [g(), g(), g(), b()],
714                {},
715                _sfdp_params_check,
716            ),
717            (
718                _sfdp_pattern_6,
719                _sfdp_replacement_6,
720                [g(), g(), g(), b()],
721                d,
722                _sfdp_params_check,
723            ),
724            (
725                _sfdp_pattern_7,
726                _sfdp_replacement_7,
727                [g(), g(), g()],
728                d,
729                _sfdp_params_check,
730            ),
731            (
732                _sfdp_pattern_8,
733                _sfdp_replacement_8,
734                [g(), g(), g()],
735                {},
736                _sfdp_params_check,
737            ),
738            (
739                _sfdp_pattern_9,
740                _sfdp_replacement_9,
741                [g(), g(), g()],
742                d,
743                _sfdp_params_check,
744            ),
745            (
746                _sfdp_pattern_10,
747                _sfdp_replacement_10,
748                [g(), g(), g()],
749                {},
750                _sfdp_params_check,
751            ),
752            (
753                _sfdp_pattern_11,
754                _sfdp_replacement_11,
755                [g(), g(), g(), c()],
756                {},
757                _sfdp_extra_check(aten.div.Tensor),
758            ),
759            (
760                _sfdp_pattern_12,
761                _sfdp_replacement_12,
762                [g(), g(), g(), c()],
763                d,
764                _sfdp_extra_check(aten.div.Tensor),
765            ),
766            (
767                _sfdp_pattern_13,
768                _sfdp_replacement_13,
769                [g_3d(), g_3d(), g_3d()],
770                d,
771                _sfdp_params_check,
772            ),
773            (
774                _sfdp_pattern_14,
775                _sfdp_replacement_14,
776                [g(), g(), g(), m(), c()],
777                {},
778                _sfdp_extra_check(aten.div.Tensor),
779            ),
780            (
781                _sfdp_pattern_15,
782                _sfdp_replacement_15,
783                [g(), g(), g(), m(), c()],
784                {},
785                _sfdp_extra_check(aten.div.Tensor),
786            ),
787            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
788            (
789                _sfdp_pattern_16,
790                _sfdp_replacement_16,
791                [g(), g(), g(), m(), c()],
792                d,
793                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
794            ),
795            (
796                _sfdp_pattern_16,
797                _sfdp_replacement_16,
798                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
799                d,
800                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
801            ),
802            (
803                _sfdp_pattern_17,
804                _sfdp_replacement_17,
805                [g(), g(), g(), m(), c()],
806                d,
807                _sfdp_extra_check(aten.div.Tensor),
808            ),
809            (
810                _sfdp_pattern_18,
811                _sfdp_replacement_18,
812                [g(), g(), g(), m_bool()],
813                d,
814                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
815                _sfdp_extra_check(disable_cuda=True),
816            ),
817            (
818                _sfdp_pattern_18,
819                _sfdp_replacement_18,
820                [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()],
821                d,
822                # CUDA AOT Inductor CI job's GPT2ForSequenceClassification accuracy test failed
823                _sfdp_extra_check(disable_cuda=True),
824            ),
825            (
826                _sfdp_pattern_19,
827                _sfdp_replacement_19,
828                [g(), g(), g(), b_bool(), b_float()],
829                d,
830                _sfdp_params_check,
831            ),
832        ]
833        mask_fp32_patterns = ["pattern_16"]
834        if dtype == torch.half:
835            # Add inputs of bf16 q/k/v and fp32 mask, for models like albert.
836            candidates.append(
837                (
838                    _sfdp_pattern_16,
839                    _sfdp_replacement_16,
840                    [g(), g(), g(), m_float(), c()],
841                    d,
842                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
843                )
844            )
845            candidates.append(
846                (
847                    _sfdp_pattern_16,
848                    _sfdp_replacement_16,
849                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
850                    d,
851                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
852                )
853            )
854
855        for pattern, replacement, args, workaround, extra_check in candidates:
856            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
857            # gets serialized to a python file and does not require tracing at runtime.
858            assert isinstance(workaround, dict)
859            name = pattern.__name__
860
861            if dtype != torch.float:
862                name += "_half"
863                if (
864                    any(p in name for p in mask_fp32_patterns)
865                    and args[3].dtype == torch.float32
866                ):
867                    name += "_mask_fp32"
868            if args[0].size(0) == 1:
869                name += "_bs1"
870
871            training_name = name + "_training"
872            yield training_name, {
873                "search_fn": pattern,
874                "replace_fn": replacement,
875                "example_inputs": args,
876                "trace_fn": joint_fwd_bwd,
877                "pass_dicts": patterns,
878                "extra_check": extra_check,
879                "scalar_workaround": workaround,
880            }
881
882            if workaround:
883                assert len(workaround) == 1 and "dropout_p" in workaround
884                # functools.partial insufficient because we look at signature downstream
885                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
886                replacement = partialize_and_update_signature(
887                    replacement, dropout_p=0.0
888                )
889                workaround = {}
890
891            inference_name = name + "_inference"
892            yield inference_name, {
893                "search_fn": pattern,
894                "replace_fn": replacement,
895                "example_inputs": args,
896                "trace_fn": fwd_only,
897                "pass_dicts": patterns,
898                "extra_check": extra_check,
899                "scalar_workaround": workaround,
900                # with dropout turned into clone, we end up with a number of
901                # semantically identical graphs
902                "skip_duplicates": True,
903            }
904
905
906@functools.lru_cache(None)
907def _sfdp_init():
908    for key, register_replacement_kwargs in _get_sfdp_patterns():
909        gen_register_replacement(key, **register_replacement_kwargs)
910