# mypy: ignore-errors

import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import collections
import logging

from dataclasses import is_dataclass, fields


from .graph import magic_methods, reflectable_magic_methods, Graph
from torch.utils._traceback import CapturedTraceback
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback

__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
           'ScopeContextManager']


log = logging.getLogger(__name__)


@compatibility(is_backward_compatible=False)
class Scope:
    """Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type


@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """A context manager to track the Scope of a Node during symbolic tracing.
    When entering the forward function of a Module, we update the scope
    information to that of the current module, and when we exit, we restore
    the previous scope information.
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # Keep a copy of prev scope to restore on exit
        self._prev_scope = copy.copy(scope)
        # Update scope to current scope
        scope.module_path = current_scope.module_path
        scope.module_type = current_scope.module_type
        # Save a reference so we can restore it
        self._scope = scope

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        self._scope.module_path = self._prev_scope.module_path
        self._scope.module_type = self._prev_scope.module_type
        return
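
# Example usage (a minimal sketch; ``tracer`` and ``Sub`` are hypothetical
# stand-ins for a TracerBase subclass and a user module): while the ``with``
# block is active, nodes created by the tracer are attributed to the
# submodule's scope, and the previous scope is restored on exit.
#
#     with ScopeContextManager(tracer.scope, Scope("sub", Sub)):
#         ...  # create_node() here records module_path="sub"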

_COPY_META_FIELDS = [
    "nn_module_stack",
    "torch_fn",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "ac_graph_id",
    "from_node",
    "quantization_tag",  # TODO deprecated
    "_numeric_debug_handle",  # TODO deprecated
    "custom",
    "partitioner_tag"
]


@compatibility(is_backward_compatible=True)
class TracerBase:
    graph: Graph
    record_stack_traces : bool = False
    # Feature flag for mutable schema checking
    # Enabled by default in 1.12
    check_mutable_operations : bool = False
    # Feature flag for assert tracing
    trace_asserts : bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes : bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # Maps the containing module's name to the operator name
    scope : Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.
        """

        if kind == 'call_function' and self.check_mutable_operations:
            check_for_mutable_operation(target, args, kwargs)

        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
        # node.meta['nn_module_stack']
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        # Optionally set stack trace on the created Node for debugging purposes
        if fx_traceback.has_preserved_node_meta():
            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

            stack_trace = current_meta.get("stack_trace")
            if stack_trace:
                node.stack_trace = stack_trace
            # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
            # If other meta fields are needed, they can be added here
            for field in _COPY_META_FIELDS:
                if field in current_meta:
                    node.meta[field] = copy.copy(current_meta[field])

            # Here we decrement to account for the sequence_nr having
            # just been incremented while tracing this lowered aten op.
            new_seq_nr = torch.autograd._get_sequence_nr() - 1
            # The sequence_nr increments every time a new autograd Node
            # is created. During the FWD pass we store the sequence_nr
            # corresponding to the last autograd Node created on this fx
            # node's meta. A single aten op can create multiple autograd
            # nodes, as is the case with in-place foreach ops. During the
            # BWD pass we retrieve the sequence_nr stored on the current
            # executing autograd Node. See NOTE [ Sequence Number ].
            if current_meta.get("in_grad_fn", 0) > 0:
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
            node.meta["seq_nr"] = new_seq_nr

        elif self.module_stack:
            node.meta['nn_module_stack'] = copy.copy(self.module_stack)

        log.debug("create_node %s", node)
        return node
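
    # Example of the override mentioned in the docstring above (a minimal
    # sketch; ``NoInplaceTracer`` is a hypothetical subclass, and the
    # trailing-underscore test is only a heuristic for in-place Tensor
    # methods, not an exhaustive check):
    #
    #     class NoInplaceTracer(torch.fx.Tracer):
    #         def create_node(self, kind, target, args, kwargs,
    #                         name=None, type_expr=None):
    #             if kind == 'call_method' and isinstance(target, str) and target.endswith('_'):
    #                 raise TraceError(f"in-place method '{target}' is not allowed")
    #             return super().create_node(kind, target, args, kwargs, name, type_expr)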

    @compatibility(is_backward_compatible=True)
    def proxy(self, node: Node) -> 'Proxy':
        return Proxy(node, self)

    @compatibility(is_backward_compatible=True)
    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.
        '''

        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        if self.record_stack_traces and not proxy.node.stack_trace:
            proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())

        return proxy

    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.
        """
        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame executing
        # the user code during tracing.
        frame = inspect.currentframe()

        pt_files = ['torch/fx/proxy.py',
                    'torch/fx/_symbolic_trace.py',
                    'torch/fx/experimental/proxy_tensor.py',
                    'torch/_ops.py',
                    'torch/_tensor.py',
                    'torch/utils/_python_dispatch.py',
                    'torch/_prims_common/wrappers.py',
                    'torch/_refs/__init__.py',
                    'torch/_refs/nn/functional/__init__.py',
                    'torch/utils/_stats.py',
                    ]
        while frame:
            frame = frame.f_back
            if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
                break

        if not frame:
            return None

        return frame
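
    # Example (a minimal sketch; ``tracer`` stands for any TracerBase instance
    # and ``p`` for an existing Proxy): this appends a call_function node
    # computing torch.add(p, 1) to the graph and returns a new Proxy wrapping
    # that node.
    #
    #     out = tracer.create_proxy('call_function', torch.add, (p, 1), {})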

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.
        """
        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
            return a.__fx_create_arg__(self)
        # aggregates
        elif isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors don't seem to like getting a generator
            # expression as an argument to their constructor, so build this
            # intermediate tuple and unpack it into the NamedTuple constructor
            args = tuple(self.create_arg(elem) for elem in a)
            return type(a)(*args)  # type: ignore[arg-type]
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)

                def no_node(arg):
                    if isinstance(arg, Node):
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                           f"Node. Got key: {k}")
                map_aggregate(k, no_node)

                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node

        if is_dataclass(a):
            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
            return self.create_node("call_function", a.__class__, (), kwargs)

        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
            return a

        raise NotImplementedError(f"argument of type: {type(a)}")

    @compatibility(is_backward_compatible=True)
    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow. Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    @compatibility(is_backward_compatible=True)
    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow. Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.
        """
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as a *args or **kwargs function argument. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: 'Proxy') -> Any:
        """Called when a proxy object has the keys() method called.
        This is what happens when ** is called on a proxy. This should return an
        iterator if ** is supposed to work in your custom tracer.
        """
        return Attribute(obj, 'keys')()
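
    # Example of the ``__fx_create_arg__`` hook honored by create_arg above (a
    # minimal sketch; ``MyConfig`` is hypothetical): a custom type can lower
    # itself to an IR-storable Argument instead of requiring a create_arg
    # override.
    #
    #     class MyConfig:
    #         def __init__(self, scale):
    #             self.scale = scale
    #
    #         def __fx_create_arg__(self, tracer: 'TracerBase'):
    #             # lower to a plain value that the graph can store
    #             return tracer.create_arg(self.scale)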

# used in Proxy object when just appending to the graph while not tracing.
@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
    def __init__(self, graph: Graph):
        super().__init__()
        self.graph = graph
        self.scope = Scope("", None)
        self.module_stack = collections.OrderedDict()
        self.node_name_to_scope = {}


@compatibility(is_backward_compatible=False)
def assert_fn(x):
    assert x


@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
    pass


@compatibility(is_backward_compatible=True)
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap your own ``Proxy``
    method around a raw ``Node`` so that you can use the overloaded
    operators to add additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
       use ``fx.wrap`` on it.
    2. If the control flow is static (i.e. the loop trip count is
       based on some hyperparameter), the code can be kept in its original
       position and refactored into something like::

        for i in range(self.some_hyperparameter):
            indexed_item = proxied_value[i]

    For a more detailed description of the Proxy internals, check out
    the "Proxy" section in `torch/fx/README.md`
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node
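
    # Example of wrapping a raw Node, per the comment in ``__init__`` above (a
    # minimal sketch; ``node`` stands for an existing torch.fx.Node): the
    # omitted tracer defaults to a GraphAppendingTracer, so overloaded
    # operators on ``p`` simply append nodes to ``node.graph``.
    #
    #     p = Proxy(node)
    #     q = p + 1  # appends a call_function node for operator.add(node, 1)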

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __getstate__(self) -> Dict:
        return self.__dict__

    def __deepcopy__(self, memo) -> Dict:
        # We have to explicitly override this method, because otherwise deepcopy
        # will go to __getattr__(self, "__deepcopy__") and return an
        # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
        import copy
        new_dict = {}
        for k, v in self.__dict__.items():
            try:
                new_obj = copy.deepcopy(v, memo)
            except Exception:
                log.warning(
                    "Shallow copy %s of Proxy because it cannot be deepcopied. "
                    "Proxy is created for node %s", k, self.node.name)
                new_obj = copy.copy(v)
            new_dict[k] = new_obj
        assert "node" in new_dict
        assert "tracer" in new_dict
        new_proxy = Proxy(new_dict["node"], new_dict["tracer"])
        for k, v in new_dict.items():
            new_proxy.__dict__[k] = v
        return new_proxy

    def __setstate__(self, d):
        # This is called when being unpickled/loaded.
        self.__dict__ = d

    def __call__(self, *args, **kwargs) -> 'Proxy':
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

    def __iter__(self) -> Iterator['Proxy']:
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        inst_list = list(dis.get_instructions(calling_frame.f_code))
        if sys.version_info >= (3, 11):
            from bisect import bisect_left
            inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset)
        else:
            inst_idx = calling_frame.f_lasti // 2
        inst = inst_list[inst_idx]
        if inst.opname == 'UNPACK_SEQUENCE':
            return (self[i] for i in range(inst.argval))  # type: ignore[index]

        return self.tracer.iter(self)

    def __abs__(self):
        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})

    def __bool__(self) -> bool:
        if self.tracer.trace_asserts:
            # check if this boolean is used in an assertion; the bytecode
            # pattern for assertions is pretty stable for Python 3.7--3.9
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts = list(dis.get_instructions(calling_frame.f_code))
            if sys.version_info >= (3, 11):
                from bisect import bisect_left
                cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
            else:
                cur = calling_frame.f_lasti // 2
            inst = insts[cur]

            if inst.opname == 'POP_JUMP_IF_TRUE':
                first = insts[cur + 1]
                assert inst.arg is not None
                last = insts[inst.arg // 2 - 1]
                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                      or first.opname == 'LOAD_ASSERTION_ERROR')
                if starts_with_assert and last.opname == 'RAISE_VARARGS':
                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                    return True

        return self.tracer.to_bool(self)

    @compatibility(is_backward_compatible=True)
    def keys(self):
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")
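
    # Example of the workaround named in ``__len__`` above (a minimal sketch):
    # calling ``torch.fx.wrap('len')`` at module scope in the code being
    # traced registers ``len`` as a leaf function, so calls to it are recorded
    # as call_function nodes instead of raising.
    #
    #     torch.fx.wrap('len')
    #
    #     class M(torch.nn.Module):
    #         def forward(self, x):
    #             return len(x)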

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        tracers : Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None
        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                               f'trying to trace operations {orig_method}')
        tracer = next(iter(tracers.keys()))

        if isinstance(orig_method, torch._C.ScriptMethod):
            args = (orig_method.owner,) + args
            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            if isinstance(orig_method, torch._ops.HigherOrderOperator):
                # TODO: Define how to symbolically trace HigherOrderOperators
                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
            return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                       name=tracer.graph._target_to_str(orig_method.__name__))


@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
    @compatibility(is_backward_compatible=True)
    def __init__(self, root: Proxy, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getattr call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)


@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
    """
    A special proxy which lets "shape", "size", "dim", and a few other
    attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw an exception
    during tracing.
    """
    def __init__(self, tracer: TracerBase, node: Node, name, param):
        super().__init__(node, tracer)
        assert isinstance(param, torch.nn.Parameter)
        self.param = param
        self.name = name

    def __repr__(self) -> str:
        return f'ParameterProxy({self.name})'

    @property
    def shape(self):
        return self.param.shape

    def size(self):
        return self.param.size()

    def dim(self):
        return self.param.dim()

    @property
    def ndim(self):
        return self.param.ndim

    def numel(self):
        return self.param.numel()

    def nelement(self):
        return self.param.nelement()


for method in magic_methods:
    def _scope(method):
        def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
            return tracer.create_proxy('call_function', target, args, kwargs)
        impl.__name__ = method
        as_magic = f'__{method.strip("_")}__'
        setattr(Proxy, as_magic, impl)
    _scope(method)
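
# The loop above installs, e.g., ``Proxy.__add__``, so that ``p + q`` on
# proxies records a call_function node targeting operator.add. A minimal
# sketch of the equivalence (``p`` and ``q`` stand for proxies from the same
# tracer):
#
#     r = p + q
#     # is recorded as:
#     r = p.tracer.create_proxy('call_function', operator.add, (p, q), {})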

def _define_reflectable(orig_method_name):
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        return self.tracer.create_proxy('call_function', target, (rhs, self), {})
    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(Proxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
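
# Example of a reflected method installed above (a minimal sketch; ``p``
# stands for a Proxy): when the left operand is not a Proxy, Python
# dispatches to the reflected method, and the recorded node preserves the
# original operand order.
#
#     r = 2 - p  # calls p.__rsub__(2), recording operator.sub(2, p)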