# mypy: ignore-errors

from inspect import getattr_static
from typing import TYPE_CHECKING

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class SDPAParamsVariable(VariableTracker):
    """Represents the C++ params struct for scaled dot product attention.
    This is a read-only container."""

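    # Illustrative sketch (hedged; not from the original file): the wrapped
    # object is torch.backends.cuda.SDPAParams, constructed with its fields
    # in the positional order that `create` below relies on:
    #
    #   from torch.backends.cuda import SDPAParams, can_use_flash_attention
    #   q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    #   params = SDPAParams(q, k, v, None, 0.0, True, False)
    #   can_use_flash_attention(params)  # ask the flash-attention backend
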
    @staticmethod
    def create(tx: "InstructionTranslator", value, source):
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        # Wrap each field of the input SDPAParams as a VariableTracker, then
        # rebuild the struct in-graph by calling the SDPAParams constructor
        # on the wrapped arguments.
        query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
        key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
        value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
        attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
            value.attn_mask
        )
        dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
            value.is_causal
        )
        enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
            value.enable_gqa
        )
        param_vars = [
            query_var,
            key_var,
            value_var,
            attn_mask_var,
            dropout_var,
            is_causal_var,
            enable_gqa_var,
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

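    # Hedged sketch of when `create` fires: VariableBuilder routes a
    # pre-existing SDPAParams input here, e.g. (illustrative names):
    #
    #   @torch.compile
    #   def check(p):
    #       return torch.backends.cuda.can_use_flash_attention(p)
    #
    #   check(SDPAParams(q, k, v, None, 0.0, True, False))
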
    def __init__(self, proxy, param_vars, **kwargs) -> None:
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        assert self.source is None
        assert self.param_vars is not None
        # Emit bytecode that rebuilds the struct at runtime: load
        # torch._C._SDPAParams, push each saved param, then call it.
        codegen.add_push_null(
            lambda: codegen.load_import_from("torch._C", "_SDPAParams")
        )
        codegen.foreach(self.param_vars)
        codegen.extend_output(create_call_function(len(self.param_vars), False))

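    # Hedged sketch: modulo CPython-version bytecode details, `reconstruct`
    # emits the equivalent of
    #
    #   torch._C._SDPAParams(query, key, value, attn_mask,
    #                        dropout, is_causal, enable_gqa)
    #
    # with each argument coming from the saved param_vars.
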
    def as_proxy(self):
        return self.proxy

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        # Reject attributes that do not exist on the C++ struct up front,
        # rather than emitting a getattr node that would fail at runtime.
        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # Using raise from is too verbose here
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is not None:
            return wrap_fx_proxy(
                tx=tx, proxy=proxy, source=AttrSource(self.source, name)
            )
        else:
            return wrap_fx_proxy(tx=tx, proxy=proxy)

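    # Hedged sketch: inside a compiled region, attribute reads on the wrapped
    # struct land here and become getattr proxies in the FX graph, e.g.
    #
    #   @torch.compile
    #   def f(p):
    #       return p.is_causal and p.dropout == 0.0
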
    @staticmethod
    def is_sdpa_params(value):
        from torch.backends.cuda import SDPAParams

        # Identity check against the SDPAParams class itself, not an
        # isinstance test on instances.
        return value is SDPAParams