1# mypy: ignore-errors 2 3import inspect 4from typing import Dict, List, TYPE_CHECKING 5 6import torch.utils._pytree as pytree 7from torch._guards import Source 8from torch.overrides import _get_overloaded_args, get_default_nowrap_functions 9from torch.utils._device import DeviceContext 10 11from ..exc import unimplemented 12from ..guards import GuardBuilder, install_guard 13from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource 14from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter 15from .base import VariableTracker 16from .constant import ConstantVariable 17from .ctx_manager import ContextWrappingVariable 18from .lists import TupleVariable 19from .tensor import TensorSubclassVariable, TensorVariable 20from .user_defined import UserDefinedObjectVariable 21 22 23if TYPE_CHECKING: 24 from torch._dynamo.symbolic_convert import InstructionTranslator 25 26 27# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues): 28# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches 29# __torch_function__ on attribute accesses, method calls, and torch API calls. 
30# The following is not supported: 31# - triggering __torch_function__ on tensor subclass non-tensor custom attributes 32# - graph breaking on mutating guardable tensor properties within a __torch_function__ context, this can cause 33# excessive recompiles in certain degenerate cases 34# - Matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of Torch API calls 35 36# The following is supported: 37# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as 38# any argument 39# - triggering __torch_function__ on torch API calls with tensor subclass arguments 40# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances 41# - matches the dispatch ordering behavior of eager __torch_function__ with subclass/object argumnents in any argument position 42 43# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w 44# for more information on the design. 45 46# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py 47 48 49banned_attrs = [ 50 fn.__self__.__name__ 51 for fn in get_default_nowrap_functions() 52 if is_tensor_base_attr_getter(fn) 53] 54 55# Today set default device is placed in the graph and guarded on separately 56# so we should not trace through it. 
# In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}


class TorchFunctionModeStackVariable(VariableTracker):
    """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""

    # singleton value representing the global torch function mode stack
    # singleton (it exists in C++)
    stack_value_singleton = object()

    # offset is used to track if we have inserted/removed a
    # device context which is always placed at the bottom of the stack
    # if a device context is inserted, the graph will run this mutation
    # so when we want to reconstruct any other modes on the stack
    # their indices should be shifted right by 1 (+1)
    # Conversely, if there was a device context on the stack, and the graph
    # mutates the stack to remove that context (set default device to None)
    # each of the indices of other modes should be shifted left by 1 (-1)
    offset = 0

    def __init__(self, source, symbolic_stack):
        # NOTE(review): VariableTracker.__init__ is intentionally left
        # uncalled as in the original; the base-class contract is not visible here.
        self.source = source
        self.symbolic_stack = symbolic_stack

    @classmethod
    def reset(cls):
        """Reset the class-level device-context ``offset`` to 0."""
        cls.offset = 0

    @classmethod
    def register_mutation(cls, tx: "InstructionTranslator"):
        """Record, once per ``side_effects`` instance, that the mode stack is mutated.

        The first call registers a VT wrapping ``tx.symbolic_torch_function_mode_stack``
        under ``stack_value_singleton`` and marks it as mutated; later calls are
        no-ops while that key is tracked.
        """
        if cls.stack_value_singleton not in tx.output.side_effects:
            var = cls(
                source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
            )
            tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
            tx.output.side_effects.mutation(var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        """Insert a device-context placeholder at the bottom (index 0) of the symbolic stack.

        No-op when index 0 already holds a device context (see
        ``is_device_context``). Otherwise increments ``offset`` and inserts a
        ``TorchFunctionModeVariable(None)`` sourced from
        ``TorchFunctionModeStackSource(-offset)``.
        """
        stack = tx.symbolic_torch_function_mode_stack
        if stack and cls.is_device_context(stack[0]):
            return
        cls.offset += 1
        stack.insert(
            0,
            TorchFunctionModeVariable(
                None, source=TorchFunctionModeStackSource(-cls.offset)
            ),
        )

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        """Remove a device context from the bottom of the symbolic stack, if one is there.

        Decrements ``offset`` when an entry is removed.
        """
        stack = tx.symbolic_torch_function_mode_stack
        if stack and cls.is_device_context(stack[0]):
            stack.popleft()
            cls.offset -= 1

    @staticmethod
    def is_device_context(var):
        """Return True if ``var`` wraps a DeviceContext or is the ``None`` placeholder."""
        return isinstance(var.value, DeviceContext) or var.value is None

    @classmethod
    def get_mode_index(cls, ind):
        """Shift a mode stack index by the current device-context ``offset``."""
        return ind + cls.offset


class TorchFunctionModeVariable(ContextWrappingVariable):
    """VT for a torch function mode object held in ``value``.

    ``value`` is ``None`` for the device-context placeholder inserted by
    ``TorchFunctionModeStackVariable.register_device_context_insertion``.
    """

    def __init__(self, value, **kwargs):
        super().__init__(value, **kwargs)
        self.value = value

    @staticmethod
    def get_global_mangled_name(tx, val):
        """Return a safe global name for ``val`` derived from its class name."""
        return get_safe_global_name(
            tx, f"__torch_function_mode_{val.__class__.__name__}", val
        )

    def reconstruct(self, codegen):
        # We don't support locally created torch function modes yet
        assert self.source
        self.source.reconstruct(codegen)

    def _call_func(self, tx, values):
        unimplemented("torch function mode context manager is not supported yet")


def _get_all_args(args, kwargs):
    """Return the leaf VTs of ``args``/``kwargs`` in call order (see ``_flatten_vts``)."""
    return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))


def _flatten_vts(vts):
    """Flatten ListVariable items and ConstDictVariable values into a flat list of VTs.

    Leaves are produced in left-to-right, depth-first order (the order they
    appear in the call). ``_get_overloaded_args`` relies on input order to
    decide which of several unrelated override types dispatches first, so
    this order must not be reversed. Every VT that is neither a ListVariable
    nor a ConstDictVariable is a leaf.
    """
    from .dicts import ConstDictVariable
    from .lazy import LazyVariableTracker
    from .lists import ListVariable

    # Pending VTs are kept in *reverse* order so that pop() always yields the
    # leftmost pending one; children are pushed reversed for the same reason.
    pending = list(vts)[::-1]
    output = []

    while pending:
        vt = pending.pop()
        LazyVariableTracker.realize_all(vt)
        if isinstance(vt, ListVariable):
            pending.extend(reversed(vt.items))
        elif isinstance(vt, ConstDictVariable):
            pending.extend(reversed(list(vt.items.values())))
        else:
            output.append(vt)

    return output


def _get_subclass_type(var):
    """Return the Python type of an overriding VT (type function for ``_get_overloaded_args``)."""
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    return var.python_type()


def _get_subclass_type_var(tx: "InstructionTranslator", var):
    """Return a VT representing the type of ``var``.

    Tensor subclasses use ``class_type_var`` (a globally installed class).
    User-defined objects get a ``TypeSource``-backed VT when ``var`` has a
    source, otherwise a sourceless VT.
    """
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    if isinstance(var, TensorWithTFOverrideVariable):
        return var.class_type_var(tx)
    elif isinstance(var, UserDefinedObjectVariable):
        from .builder import SourcelessBuilder, VariableBuilder

        if var.source:
            return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
        else:
            return SourcelessBuilder.create(tx, var.python_type())


def _is_attr_overidden(tx: "InstructionTranslator", var, name):
    """Return True if ``type(var)`` exposes ``name`` differently from ``torch.Tensor``.

    Returns False if ``name`` is missing from either the subclass type or
    ``torch.Tensor`` (the AttributeError from either lookup is swallowed).
    ``tx`` is unused.
    """
    import torch

    overridden = False
    try:
        attr_val = inspect.getattr_static(var.python_type(), name)
        overridden |= attr_val != getattr(torch.Tensor, name)
    except AttributeError:
        pass

    return overridden


def call_torch_function(
    tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
    """Inline ``torch_function_var(torch_function_type, fn, types, tuple(args), kwargs)``."""
    from .builder import SourcelessBuilder

    # signature:
    # def __torch_function__(cls, func, types, args=(), kwargs=None):
    tf_args = (
        torch_function_type,
        fn,
        types,
        SourcelessBuilder.create(tx, tuple(args)),
        SourcelessBuilder.create(tx, kwargs),
    )
    return tx.inline_user_function_return(torch_function_var, tf_args, {})


def build_torch_function_fn(tx: "InstructionTranslator", value, source):
    """Build a VT for ``value.__torch_function__.__func__``.

    Sourced via ``source.__torch_function__.__func__`` when ``source`` is
    given, otherwise sourceless.
    """
    from .builder import SourcelessBuilder, VariableBuilder

    if source:
        return VariableBuilder(
            tx,
            AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
        )(value.__torch_function__.__func__)
    else:
        return SourcelessBuilder.create(tx, value.__torch_function__.__func__)


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    """True if torch function is enabled and any flattened argument has a __torch_function__."""
    return tx.output.torch_function_enabled and any(
        has_torch_function(arg) for arg in _get_all_args(args, kwargs)
    )


def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
    """Dispatch ``fn`` through the ``__torch_function__`` overrides found in ``args``/``kwargs``.

    Every flattened argument for which ``has_torch_function`` holds is
    ordered by ``_get_overloaded_args``. Each override is tried in that
    order, and the first result that is not ``NotImplemented`` is returned.
    Graph-breaks if every override returns ``NotImplemented`` (or none exist).
    """
    all_args = _get_all_args(args, kwargs)
    overloaded_args = _get_overloaded_args(
        [arg for arg in all_args if has_torch_function(arg)],
        _get_subclass_type,
    )

    # ``types`` is the same for every override, so build it once.
    types = TupleVariable(
        [_get_subclass_type_var(tx, overloaded) for overloaded in overloaded_args]
    )

    for arg in overloaded_args:
        res = arg.call_torch_function(tx, fn, types, args, kwargs)

        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
            return res

    unimplemented(
        f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
    )


class TensorWithTFOverrideVariable(TensorVariable):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(self, *args, **kwargs) -> None:
        self.torch_function_fn = kwargs.pop("torch_function_fn")
        super().__init__(*args, **kwargs)

    @classmethod
    def from_tensor_var(cls, tx, tensor_var, class_type, torch_function_fn):
        """Re-wrap a plain ``TensorVariable`` (class_type ``torch.Tensor``) as a subclass VT.

        Copies ``tensor_var``'s attributes, swaps in ``class_type`` and
        ``torch_function_fn``, and installs the subclass type as a global.
        """
        import torch

        kwargs = dict(tensor_var.__dict__)
        assert (
            kwargs.pop("class_type") is torch.Tensor
        ), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
        var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
        var.install_global(tx)
        return var

    def install_global(self, tx):
        # stash the subclass type to rewrap an output tensor if needed
        # this is needed because the actual type needs to be available
        # each time the compiled artifact is run and outputs a wrapped tensor.
        global_name = self.global_mangled_class_name(tx)
        if global_name not in tx.output.global_scope:
            # Safe because global_mangled_class_name figures it out
            tx.output.install_global_unsafe(global_name, self.class_type)

    def python_type(self):
        return self.class_type

    def class_type_var(self, tx):
        """Return a VT for the subclass type, sourced from its installed global name."""
        return TensorSubclassVariable(
            self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
        )

    def global_mangled_class_name(self, tx):
        """Return a safe global name for ``self.class_type`` derived from its ``__name__``."""
        return get_safe_global_name(
            tx, f"__subclass_{self.class_type.__name__}", self.class_type
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Trace ``self.<name>``.

        Graph-breaks for names in ``banned_attrs`` and for attributes the
        subclass overrides. With torch function enabled and ``name`` defined
        on ``torch.Tensor``, the attribute getter is dispatched through this
        subclass's ``__torch_function__``. When sourced, a FUNCTION_MATCH guard
        is installed on ``type(self).<name>``. Otherwise defers to
        ``TensorVariable.var_getattr``.
        """
        # [Note: __torch_function__] We currently only support attributes that are defined on
        # base tensors, custom attribute accesses will graph break.
        import torch

        from .builder import SourcelessBuilder

        if name in banned_attrs:
            unimplemented(
                f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
            )

        if _is_attr_overidden(tx, self, name):
            unimplemented(
                f"Accessing overridden method/attribute {name} on a tensor"
                " subclass with a __torch_function__ override is not supported"
            )

        if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
            if self.source:
                install_guard(
                    AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
                        GuardBuilder.FUNCTION_MATCH
                    )
                )
            get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)

            return self.call_torch_function(
                tx,
                get_fn,
                TupleVariable([self.class_type_var(tx)]),
                [self],
                {},
            )
        else:
            return super().var_getattr(tx, name)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        """Inline this subclass's ``__torch_function__`` for ``fn``."""
        return call_torch_function(
            tx,
            self.class_type_var(tx),
            self.torch_function_fn,
            fn,
            types,
            args,
            kwargs,
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace ``self.<name>(*args, **kwargs)``.

        With torch function enabled, the method is dispatched via
        ``dispatch_torch_function`` with ``self`` prepended to ``args``.
        Otherwise defers to ``TensorVariable.call_method``. Graph-breaks when
        the subclass overrides ``name``, or when ``name`` cannot be found: on
        the subclass type when sourced, and on ``torch.Tensor`` otherwise.
        """
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        if not tx.output.torch_function_enabled:
            return super().call_method(tx, name, args, kwargs)

        import torch

        from .builder import SourcelessBuilder, VariableBuilder

        if _is_attr_overidden(tx, self, name):
            unimplemented(
                f"Calling overridden method {name} on a tensor"
                " subclass with a __torch_function__ override is not supported"
            )

        # [Note: __torch_function__] Currently we only support methods that are defined on tensor
        # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality
        # We've established with the above check that the method is not overridden, so we guard that the method is the same
        # as the impl defined on tensor and retrieve it
        if self.source:
            missing = object()
            method = inspect.getattr_static(self.python_type(), name, missing)
            if method is missing:
                # Previously an uncaught AttributeError; graph break instead.
                unimplemented(
                    f"Calling method {name}, which is not defined on the type of a tensor"
                    " subclass with a __torch_function__ override, is not supported"
                )
            func_var = VariableBuilder(
                tx, AttrSource(AttrSource(self.source, "__class__"), name)
            )(method)
        else:
            if not hasattr(torch.Tensor, name):
                # Previously an uncaught AttributeError; graph break instead.
                unimplemented(
                    f"Calling method {name}, which is not defined on torch.Tensor, on a tensor"
                    " subclass with a __torch_function__ override is not supported"
                )
            func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
        return dispatch_torch_function(tx, func_var, [self] + args, kwargs)