# mypy: ignore-errors

from inspect import getattr_static
from typing import TYPE_CHECKING

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class SDPAParamsVariable(VariableTracker):
    """Tracks a ``torch.backends.cuda.SDPAParams`` object while tracing.

    SDPAParams wraps the C++ parameter struct used by scaled dot product
    attention. This tracker only supports reading from it.
    """

    @staticmethod
    def create(tx: "InstructionTranslator", value, source):
        """Build a tracker for an existing SDPAParams ``value`` located at ``source``.

        Each field of ``value`` is wrapped by a ``VariableBuilder`` keyed on an
        ``AttrSource`` beneath ``source``. The wrapped fields are then passed
        positionally to an in-graph call of ``SDPAParams``, and the result of
        that call is returned.
        """
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        # Order matters: these are handed to the SDPAParams constructor as
        # positional arguments.
        field_names = (
            "query",
            "key",
            "value",
            "attn_mask",
            "dropout",
            "is_causal",
            "enable_gqa",
        )
        wrapped_fields = [
            VariableBuilder(tx, AttrSource(source, field))(getattr(value, field))
            for field in field_names
        ]
        ctor = TorchInGraphFunctionVariable(SDPAParams)
        return ctor.call_function(tx, wrapped_fields, {})

    def __init__(self, proxy, param_vars, **kwargs) -> None:
        # ``proxy`` is returned by as_proxy() and used as the base of traced getattrs.
        # ``param_vars`` holds the field trackers used by reconstruct(). It may
        # presumably be None, since reconstruct() asserts that it is not.
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``torch._C._SDPAParams(*param_vars)``."""
        # Only trackers without a source are rebuilt from their fields.
        assert self.source is None
        fields = self.param_vars
        assert fields is not None

        def push_ctor():
            return codegen.load_import_from("torch._C", "_SDPAParams")

        codegen.add_push_null(push_ctor)
        codegen.foreach(fields)
        codegen.extend_output(create_call_function(len(fields), False))

    def as_proxy(self):
        """Return the proxy supplied at construction time."""
        return self.proxy

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        """Trace ``self.<name>`` as a getattr on this object's proxy.

        Raises:
            Unsupported: if ``torch._C._SDPAParams`` has no attribute ``name``.
        """
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        try:
            # Static lookup: checks that the attribute exists without
            # triggering descriptors or __getattr__.
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # `from None` hides the AttributeError context; chaining it would
            # only make the error message noisier.
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        # Pass a source only when this object has one; otherwise omit the kwarg.
        wrap_kwargs = {}
        if self.source is not None:
            wrap_kwargs["source"] = AttrSource(self.source, name)
        return wrap_fx_proxy(tx=tx, proxy=proxy, **wrap_kwargs)

    @staticmethod
    def is_sdpa_params(value):
        """Return True iff ``value`` is the ``torch.backends.cuda.SDPAParams`` class itself."""
        from torch.backends.cuda import SDPAParams as sdpa_params_cls

        return value is sdpa_params_cls