# mypy: allow-untyped-defs
import builtins
import copy
import contextlib
import functools
import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject  # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject

from ._compatibility import compatibility
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

_proxyable_classes: Dict[Type, None] = {}

_is_fx_tracing_flag = False


def is_fx_tracing():
    return _is_fx_tracing_flag


@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
    """
    ProxyableClassMeta allows you to make construction of a given Python class
    symbolically traceable. For example::

        import torch
        import torch.fx

        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y);  y = None
            add = x.add(tensor_pair);  tensor_pair = None
            mul = add.mul(x);  add = x = None
            return mul
        '''

    From this example, we can see that construction of a class (``TensorPair``)
    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
    tracing.
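
    If no ``Proxy`` objects are found among the constructor arguments, the
    class is constructed eagerly, exactly as it would be outside of tracing
    (see ``__call__`` below).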
96 """ 97 98 def __init__(cls, name, bases, attrs): 99 _proxyable_classes.setdefault(cls) 100 super().__init__(name, bases, attrs) 101 102 def __call__(cls, *args, **kwargs): 103 instance = cls.__new__(cls) # type: ignore[call-overload] 104 105 if not is_fx_tracing(): 106 cls.__init__(instance, *args, **kwargs) # type: ignore[misc] 107 return instance 108 109 found_proxies = [] 110 111 def check_proxy(a): 112 if isinstance(a, Proxy): 113 found_proxies.append(a) 114 115 map_aggregate(args, check_proxy) 116 map_aggregate(kwargs, check_proxy) 117 118 if len(found_proxies) != 0: 119 tracer = found_proxies[0].tracer 120 return tracer.create_proxy("call_function", cls, args, kwargs) 121 else: 122 cls.__init__(instance, *args, **kwargs) # type: ignore[misc] 123 return instance 124 125 126def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: 127 co = fn.__code__ 128 co_flags = co.co_flags & ~HAS_VARSTUFF 129 co_args: tuple 130 if hasattr(co, "co_qualname"): 131 # Python-3.11+ code signature 132 co_args = ( 133 nargs, 134 0, 135 0, 136 co.co_nlocals, 137 co.co_stacksize, 138 co_flags, 139 co.co_code, 140 co.co_consts, 141 co.co_names, 142 co.co_varnames, 143 co.co_filename, 144 co.co_name, 145 co.co_qualname, # type: ignore[attr-defined] 146 co.co_firstlineno, 147 co.co_lnotab, 148 co.co_exceptiontable, # type: ignore[attr-defined] 149 co.co_freevars, 150 co.co_cellvars, 151 ) 152 elif hasattr(co, "co_posonlyargcount"): 153 co_args = ( 154 nargs, 155 0, 156 0, 157 co.co_nlocals, 158 co.co_stacksize, 159 co_flags, 160 co.co_code, 161 co.co_consts, 162 co.co_names, 163 co.co_varnames, 164 co.co_filename, 165 co.co_name, 166 co.co_firstlineno, 167 co.co_lnotab, 168 co.co_freevars, 169 co.co_cellvars, 170 ) 171 else: 172 co_args = ( 173 nargs, 174 0, 175 co.co_nlocals, 176 co.co_stacksize, 177 co_flags, 178 co.co_code, 179 co.co_consts, 180 co.co_names, 181 co.co_varnames, 182 co.co_filename, 183 co.co_name, 184 co.co_firstlineno, 185 co.co_lnotab, 186 co.co_freevars, 187 co.co_cellvars, 188 ) 189 new_code = CodeType(*co_args) # type: ignore[arg-type] 190 return FunctionType( 191 new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ 192 ) 193 194 # we need to insert placeholder nodes for *args and **kwargs 195 # we can't call this function normally, otherwise it would try to unpack them 196 # instead, let's make python think that args and kwargs are normal variables 197 198 199@compatibility(is_backward_compatible=False) 200class PHBase: 201 """ 202 Object representing an input placeholder to `concrete_args` 203 """ 204 205 def __repr__(self): 206 return "PH" 207 208 209PH = PHBase() 210 211 212@compatibility(is_backward_compatible=False) 213class PHWithMeta(PHBase): 214 """ 215 Object representing an input placeholder to `concrete_args` 216 """ 217 def __init__(self, ph_key: Optional[str] = None): 218 super().__init__() 219 220 # Provide a hey for user to identify placeholder node during analysis 221 self.ph_key = ph_key 222 223 224def _transfer_attrs(fr, to): 225 for attr_name in dir(fr): 226 attr_val = getattr(fr, attr_name) 227 if ( 228 not callable(attr_val) 229 and not attr_name.startswith("__") 230 and not hasattr(to, attr_name) 231 ): 232 setattr(to, attr_name, attr_val) 233 234 235@compatibility(is_backward_compatible=True) 236class Tracer(TracerBase): 237 # Reference: https://github.com/pytorch/pytorch/issues/54354 238 # The first line of this docstring overrides the one Sphinx generates for the 239 # documentation. 


@compatibility(is_backward_compatible=True)
class Tracer(TracerBase):
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for
    # the documentation. We need it so that Sphinx doesn't leak `math`'s path
    # from the build environment (e.g. `<module 'math' from '/leaked/path'>`).

    """Tracer(autowrap_modules=(math,), autowrap_functions=())

    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
    to ``Tracer().trace(m)``.

    Tracer can be subclassed to override various behaviors of the tracing
    process. The different behaviors that can be overridden are described
    in the docstrings of the methods on this class.
    """

    # Not checking BC on this API because the default value for `autowrap_modules`
    # includes the local filepath to the `math` module, which would jitter
    # across machines.
    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        autowrap_modules: Tuple[ModuleType] = (math,),
        autowrap_functions: Tuple[Callable, ...] = (),
        param_shapes_constant: bool = False,
    ) -> None:
        # This method's signature is overridden by the first line of this class'
        # docstring. If this method's signature is modified, the signature that
        # overrides it also should be modified accordingly.

        """
        Construct a Tracer object.

        Args:

            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                Python modules whose functions should be wrapped automatically
                without needing to use fx.wrap(). Backward-compatibility for
                this parameter is guaranteed.

            autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
                Python functions that should be wrapped automatically without
                needing to use fx.wrap(). Backward compatibility for this
                parameter is guaranteed.

            param_shapes_constant (bool): When this flag is set, calls to shape,
                size, and a few other shape-like attributes of a module's parameter
                will be evaluated directly, rather than returning a new Proxy value
                for an attribute access. Backward compatibility for this parameter
                is guaranteed.
        """

        super().__init__()

        # Functions we will eagerly wrap when we see them while tracing;
        # this captures both `math.sqrt()` and `from math import sqrt` automatically
        self._autowrap_function_ids: Set[int] = {
            id(value)
            for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
            if not name.startswith("_") and callable(value)
        }
        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})

        # Python modules to apply autowrap to at the start, in addition to
        # modules we see while tracing
        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
        self.param_shapes_constant = param_shapes_constant

        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
        self.root_module_name: str = ""
        # Maps the containing module's name to the operator name
        self.scope = Scope("", None)
        # Records the module call stack
        self.module_stack = collections.OrderedDict()
        # Mapping of node name to module scope
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}

    _qualname_counter: Dict[str, int] = collections.defaultdict(int)

    @compatibility(is_backward_compatible=True)
    def get_fresh_qualname(self, prefix: str) -> str:
        """
        Gets a fresh name for a prefix and returns it. This function ensures
        that it will not clash with an existing attribute on the root module.
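
        Example (illustrative; assumes ``self.root`` does not already have an
        attribute named ``foo0``)::

            tracer.get_fresh_qualname("foo")  # -> "foo0"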
321 """ 322 # The idea here is that if the module doesn't have this prefix at all we 323 # should reset the counter to start from the beginning 324 # It's a ... little bit hacky (doesn't cover all cases) but the precise 325 # naming of the prefixes isn't a correctness issue, just a niceness 326 # issue 327 qualname = f"{prefix}0" 328 if not hasattr(self.root, qualname): 329 self._qualname_counter[prefix] = 0 330 return qualname 331 332 i = self._qualname_counter[prefix] 333 while True: 334 qualname = f"{prefix}{i}" 335 i += 1 336 if not hasattr(self.root, qualname): 337 break 338 self._qualname_counter[prefix] = i 339 340 return qualname 341 342 @compatibility(is_backward_compatible=True) 343 def create_arg(self, a: Any) -> "Argument": 344 """ 345 A method to specify the behavior of tracing when preparing values to 346 be used as arguments to nodes in the ``Graph``. 347 348 By default, the behavior includes: 349 350 #. Iterate through collection types (e.g. tuple, list, dict) and recursively 351 call ``create_args`` on the elements. 352 #. Given a Proxy object, return a reference to the underlying IR ``Node`` 353 #. Given a non-Proxy Tensor object, emit IR for various cases: 354 355 * For a Parameter, emit a ``get_attr`` node referring to that Parameter 356 * For a non-Parameter Tensor, store the Tensor away in a special 357 attribute referring to that attribute. 358 359 This method can be overridden to support more types. 360 361 Args: 362 363 a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. 364 365 366 Returns: 367 368 The value ``a`` converted into the appropriate ``Argument`` 369 """ 370 # The base tracer is used to construct Graphs when there is no associated 371 # module hierarchy, so it can never create parameter references. 372 # The default tracer adds the ability to refer to parameters when 373 # tracing modules. 374 if isinstance(a, torch.nn.Parameter): 375 for n, p in self.root.named_parameters(): 376 if a is p: 377 return self.create_node("get_attr", n, (), {}) 378 raise NameError("parameter is not a member of this module") 379 elif isinstance(a, torch.Tensor): 380 for n_, p_ in self.root.named_buffers(): 381 if a is p_: 382 return self.create_node("get_attr", n_, (), {}) 383 elif isinstance(a, torch.nn.Module): 384 for n_, p_ in self.root.named_modules(): 385 if a is p_: 386 return self.create_node("get_attr", n_, (), {}) 387 # For NamedTuple instances that appear literally as args, we emit 388 # a node to construct the NamedTuple and use that Node as the argument. 389 if isinstance(a, tuple) and hasattr(a, "_fields"): 390 args = tuple(self.create_arg(elem) for elem in a) 391 return self.create_node("call_function", a.__class__, args, {}) 392 393 # Tensors do not have a reliable string repr() from which they can be 394 # constructed (and we probably don't want to rely on that, either), so 395 # for any constant Tensor values we encounter, first search for if they 396 # are an attribute of some module in the module hierarchy. If so, emit 397 # a get_attr to retrieve that tensor. Otherwise, we'll store away the 398 # tensor value into a special attribute on the Module s.t. we can 399 # retrieve it with a get_attr. 
        if isinstance(a, (torch.Tensor, ScriptObject, FakeScriptObject)):
            qualname: Optional[str] = self.tensor_attrs.get(a)

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
            if not qualname:
                base_name = (
                    "_tensor_constant"
                    if isinstance(a, torch.Tensor)
                    else "_torchbind_obj"
                )
                qualname = self.get_fresh_qualname(base_name)
                assert isinstance(qualname, str)
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        if type(a) in _proxyable_classes:
            # This is an instance of a proxyable class for which we did not
            # witness its construction. Intern this as a constant attribute.

            # TODO: binary search
            qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
            assert isinstance(qualname, str)
            setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        return super().create_arg(a)

    @compatibility(is_backward_compatible=True)
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.

        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules. All other modules are traced through and
        their constituent ops are recorded, unless specified otherwise
        via this parameter.

        Args:

            m (Module): The module being queried about
            module_qualified_name (str): The path to the root of this module.
                For example, if you have a module hierarchy where submodule
                ``foo`` contains submodule ``bar``, which contains submodule
                ``baz``, that module will appear with the qualified name
                ``foo.bar.baz`` here.
        """
        return (
            m.__module__.startswith("torch.nn")
            or m.__module__.startswith("torch.ao.nn")
        ) and not isinstance(m, torch.nn.Sequential)

    @compatibility(is_backward_compatible=True)
    def path_of_module(self, mod: torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy
        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
        a submodule named ``bar``, passing ``bar`` into this function will return
        the string "foo.bar".

        Args:

            mod (Module): The ``Module`` to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            assert isinstance(path, str)
            return path
        # O(N^2) fallback in the case that we didn't store the submodule
        # paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    return n
            raise NameError("module is not installed as a submodule")
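
    # A minimal subclassing sketch (illustrative; ``InlineEverythingTracer``
    # and ``my_module`` are hypothetical names): the leaf-module policy above
    # is commonly customized by overriding ``is_leaf_module``, e.g. to trace
    # through every module instead of emitting ``call_module`` nodes:
    #
    #     class InlineEverythingTracer(Tracer):
    #         def is_leaf_module(self, m, module_qualified_name):
    #             return False
    #
    #     graph = InlineEverythingTracer().trace(my_module)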

    @compatibility(is_backward_compatible=True)
    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Method that specifies the behavior of this ``Tracer`` when it encounters
        a call to an ``nn.Module`` instance.

        By default, the behavior is to check if the called module is a leaf module
        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
        the operations in its ``forward`` function.

        This method can be overridden to, for example, create nested traced
        GraphModules, or any other behavior you would want while tracing across
        ``Module`` boundaries.

        Args:

            m (Module): The module for which a call is being emitted
            forward (Callable): The forward() method of the ``Module`` to be invoked
            args (Tuple): args of the module callsite
            kwargs (Dict): kwargs of the module callsite

        Return:

            The return value from the Module call. In the case that a ``call_module``
            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
            value was returned from the ``Module`` invocation.
        """
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(
            self.scope, Scope(module_qualified_name, type(m))
        ) as _scope:
            # module_stack is an ordered dict so writing then deleting the
            # entry is equivalent to push/pop on a list
            self.module_stack[_scope.module_path] = (
                module_qualified_name,
                _scope.module_type,
            )
            if not self.is_leaf_module(m, module_qualified_name):
                ret_val = forward(*args, **kwargs)
            else:
                ret_val = self.create_proxy(
                    "call_module", module_qualified_name, args, kwargs
                )
            key, _ = self.module_stack.popitem(last=True)
            assert key == _scope.module_path, f"Unexpected key {key}"

        return ret_val

    @compatibility(is_backward_compatible=False)
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """
        Method that specifies the behavior of this ``Tracer`` when we call getattr
        on a call to an ``nn.Module`` instance.

        By default, the behavior is to return a proxy value for the attribute. It
        also stores the proxy value in the ``parameter_proxy_cache``, so that future
        calls will reuse the proxy rather than creating a new one.

        This method can be overridden to, for example, not return proxies when
        querying parameters.

        Args:

            attr (str): The name of the attribute being queried
            attr_val (Any): The value of the attribute
            parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies

        Return:

            The return value from the getattr call.
549 """ 550 def maybe_get_proxy_for_attr( 551 attr_val, collection_to_search, parameter_proxy_cache 552 ): 553 for n, p in collection_to_search: 554 if attr_val is p: 555 if n not in parameter_proxy_cache: 556 kwargs = {} 557 if ( 558 "proxy_factory_fn" 559 in inspect.signature(self.create_proxy).parameters 560 ): 561 kwargs["proxy_factory_fn"] = ( 562 None 563 if not self.param_shapes_constant 564 else lambda node: ParameterProxy( 565 self, node, n, attr_val 566 ) 567 ) 568 val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] 569 parameter_proxy_cache[n] = val_proxy 570 return parameter_proxy_cache[n] 571 return None 572 573 if isinstance(attr_val, torch.nn.Parameter): 574 maybe_parameter_proxy = maybe_get_proxy_for_attr( 575 attr_val, self.root.named_parameters(), parameter_proxy_cache 576 ) 577 if maybe_parameter_proxy is not None: 578 return maybe_parameter_proxy 579 580 if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): 581 maybe_buffer_proxy = maybe_get_proxy_for_attr( 582 attr_val, self.root.named_buffers(), parameter_proxy_cache 583 ) 584 if maybe_buffer_proxy is not None: 585 return maybe_buffer_proxy 586 587 return attr_val 588 589 # This method will be refactored 590 @compatibility(is_backward_compatible=False) 591 def create_args_for_root(self, root_fn, is_module, concrete_args=None): 592 """ 593 Create ``placeholder`` nodes corresponding to the signature of the ``root`` 594 Module. This method introspects root's signature and emits those 595 nodes accordingly, also supporting ``*args`` and ``**kwargs``. 596 """ 597 # In some cases, a function or method has been decorated with a wrapper 598 # defined via ``functools.wraps``. In this case, the outer code object 599 # will likely not contain the actual parameters we care about, so unwrap 600 # the function to get to the innermost callable. 601 fn_for_analysis = inspect.unwrap(root_fn) 602 co = fn_for_analysis.__code__ 603 total_args = co.co_argcount + co.co_kwonlyargcount 604 orig_args = list(co.co_varnames) 605 names_iter = iter(co.co_varnames) 606 args: List[Any] = [] 607 skip_arg_idx = 0 608 if is_module: 609 if total_args == 0: 610 raise RuntimeError( 611 "``self`` argument cannot be part of *args expansion!" 612 ) 613 skip_arg_idx = 1 614 next(names_iter) # skip self 615 args.append(self.root) 616 617 sig = inspect.signature(fn_for_analysis) 618 619 620 # This covers the very specific case where we are passing in flat 621 # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). 622 # In this case, just take the concrete_args and pass them through. 623 name_idx = 0 624 if isinstance(concrete_args, tuple) and \ 625 len(concrete_args) > 0 and \ 626 (co.co_flags & HAS_VARSTUFF) and \ 627 total_args == 1: 628 for concrete_arg in concrete_args: 629 out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) 630 if isinstance(concrete_arg, PHBase): 631 if concrete_arg != PH: 632 # Transfer attrs in the case where you're using a placeholder other 633 # than the singleton PH (PH has no attributes to transfer). 634 # Proxies were created out of the placeholders. 635 # Transfer any metadata (put on the placeholders in the form of 636 # attributes set by the user) from the placeholder to the 637 # underlying nodes (the proxy is unwrapped by the user, but 638 # the metadata should hold). 
                        _transfer_attrs(fr=concrete_arg, to=out.node)
                args.append(out)
                name_idx += 1
            return root_fn, args

        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
        if isinstance(concrete_args, tuple):
            if len(arg_names) != len(concrete_args):
                raise RuntimeError(
                    f"Tracing expected {len(arg_names)} arguments but got "
                    f"{len(concrete_args)} concrete arguments"
                )
            concrete_args = dict(zip(arg_names, concrete_args))

        def proxy_placeholder(name):
            return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)

        args.extend(proxy_placeholder(name) for name in arg_names)

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder("*" + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder("**" + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        flat_args, in_spec = pytree.tree_flatten(tuple(args))
        if not all(child.is_leaf() for child in in_spec.children_specs):
            # In the case that we have pytree-flattened inputs in
            # `concrete_args`, generate a flattening wrapper around the
            # original root function and return that.
            self.graph._codegen = _PyTreeCodeGen(
                _PyTreeInfo(orig_args[:total_args], in_spec, None)
            )

            def flatten_fn(*args):
                tree_args = pytree.tree_unflatten(list(args), in_spec)
                tree_out = root_fn(*tree_args)
                out_args, out_spec = pytree.tree_flatten(tree_out)
                assert isinstance(self.graph._codegen, _PyTreeCodeGen)
                self.graph._codegen.pytree_info = (
                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
                )
                return out_args

            return flatten_fn, flat_args
        return root_fn, args

    @compatibility(is_backward_compatible=True)
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root`` passed
        in here. For example, when a free function is passed to ``trace()``, we will
        create an ``nn.Module`` instance to use as the root and add embedded constants
        to it.

        Args:

            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through. Backwards-compatibility for this parameter is
                guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
                not be treated as Proxies. This parameter is experimental and
                its backwards-compatibility is *NOT* guaranteed.

        Returns:

            A ``Graph`` representing the semantics of the passed-in ``root``.
        """
        global _is_fx_tracing_flag
        old_is_fx_tracing_flag = _is_fx_tracing_flag
        _is_fx_tracing_flag = True
        try:
            if isinstance(root, torch.nn.Module):
                # Do a real recompilation for _LazyGraphModule before retracing,
                # since the trace method cannot trace the _lazy_forward method.
                # Without this, we get errors like
                # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
                from torch.fx._lazy_graph_module import _LazyGraphModule

                _LazyGraphModule.force_recompile(root)

                self.root = root

                assert hasattr(
                    type(root), self.traced_func_name
                ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

                fn = getattr(type(root), self.traced_func_name)
                self.root_module_name = root._get_name()
                self.submodule_paths = {mod: name for name, mod in root.named_modules()}
            else:
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
            self.graph = Graph(tracer_cls=tracer_cls)
            if hasattr(fn, "__code__"):
                code = fn.__code__
                self.graph._co_fields = {
                    "co_name": code.co_name,
                    "co_filename": code.co_filename,
                    "co_firstlineno": code.co_firstlineno,
                }

            # When we encounter a Tensor value that's not a parameter, we check
            # whether it is some other attribute on the model. Construct a dict
            # mapping Tensor values to their qualified names here for efficiency.
            # This is used downstream in create_arg.
            self.tensor_attrs: Dict[
                Union[torch.Tensor, ScriptObject, FakeScriptObject], str
            ] = {}

            def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
                for k, v in m.__dict__.items():
                    if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)):
                        self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            assert isinstance(fn, FunctionType)

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args
            )

            # Reduce the number of get_attr calls
            parameter_proxy_cache: Dict[str, Proxy] = {}

            # Method dispatch on parameters is not recorded unless it's directly
            # used. Thus, we need to insert a proxy when __getattr__ requests a
            # parameter.
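            # Illustrative example: in
            #
            #     def forward(self, x):
            #         return torch.nn.functional.linear(x, self.weight)
            #
            # ``self.weight`` is resolved through Module.__getattr__, so the
            # wrapper below is what turns that access into a ``get_attr`` node.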
            @functools.wraps(_orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = _orig_module_getattr(mod, attr)
                return self.getattr(attr, attr_val, parameter_proxy_cache)

            @functools.wraps(_orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return _orig_module_call(mod, *args, **kwargs)

                _autowrap_check(
                    patcher,  # type: ignore[has-type]
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids,
                )
                return self.call_module(mod, forward, args, kwargs)

            with _new_patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(
                    torch.nn.Module,
                    "__getattr__",
                    module_getattr_wrapper,
                    deduplicate=False,
                )
                patcher.patch_method(
                    torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
                )
                _patch_wrapped_functions(patcher)
                _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
                for module in self._autowrap_search:
                    _autowrap_check(
                        patcher, module.__dict__, self._autowrap_function_ids
                    )
                self.create_node(
                    "output",
                    "output",
                    (self.create_arg(fn(*args)),),
                    {},
                    type_expr=fn.__annotations__.get("return", None),
                )

            self.submodule_paths = None
        finally:
            _is_fx_tracing_flag = old_is_fx_tracing_flag
        return self.graph

    def __deepcopy__(self, memo):
        # _autowrap_search contains modules, which cannot be deepcopied.
        new_tracer = Tracer.__new__(Tracer)

        for k, v in self.__dict__.items():
            if k in {"_autowrap_search"}:
                new_obj = copy.copy(v)
            else:
                new_obj = copy.deepcopy(v, memo)

            new_tracer.__dict__[k] = new_obj

        return new_tracer

    def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
        if concrete_args is not None and name in concrete_args:
            cnt = 0

            def replace_ph(x):
                nonlocal cnt
                cnt += 1
                param = sig.parameters[name]
                default = (
                    ()
                    if param.default is inspect.Parameter.empty
                    else (param.default,)
                )
                out = self.create_proxy(
                    "placeholder", f"{name}_{str(cnt)}", default, {}
                )
                if isinstance(x, PHBase):
                    if x != PH:
                        # Transfer attrs in the case where you're using a placeholder other
                        # than the singleton PH (PH has no attributes to transfer).
                        # Proxies were created out of the placeholders.
                        # Transfer any metadata (put on the placeholders in the form of
                        # attributes set by the user) from the placeholder to the
                        # underlying nodes (the proxy is unwrapped by the user, but
                        # the metadata should hold).
                        _transfer_attrs(fr=x, to=out.node)

                    return out
                # Union[int, bool] == bool in Python <= 3.6
                if (
                    type(x) == bool
                    or type(x) in base_types
                    and type(x) != torch.Tensor
                ):
                    torch._assert(
                        out == x,
                        f"{name} has been specialized to have value {x} but got another value",
                    )
                elif x is None:
                    args = (
                        out,
                        f"{name} has been specialized to have value None but got another value",
                    )
                    self.create_proxy("call_function", _assert_is_none, args, {})
                else:
                    warnings.warn(
                        f"Was not able to add an assertion to guarantee correct input {name} to "
                        "the specialized function. It is up to the user to make sure that the "
                        "inputs match the inputs the function was specialized with."
                    )

                return x
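
            # ``concrete_args[name]`` may itself be a pytree (e.g. a dict with
            # ``PH`` leaves, as in the ``symbolic_trace`` docstring); map
            # ``replace_ph`` over its leaves so that each leaf gets its own
            # placeholder node.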
            return pytree.tree_map(replace_ph, concrete_args[name])

        if name[0] == "*":
            default = ()
        else:
            param = sig.parameters[name]
            default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]
        return self.create_proxy(
            "placeholder",
            name,
            default,
            {},
            type_expr=fn_for_analysis.__annotations__.get(name, None),
        )


# Dictionary of (id(globals dict), function name) => globals_dict to patch for
# the purposes of the wrap() API.
# We key by the globals dict id and function name to ensure we're wrapping a given
# function only once.
_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}

# List of methods on classes to wrap (class type, function name).
# This currently only works for Tensor.* methods that aren't traced properly.
_wrapped_methods_to_patch: List[Tuple[type, str]] = []

if os.environ.get("FX_PATCH_GETITEM") == "1":
    # This change is needed to trace models like PositionalEmbedding from BERT:
    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
    # but causes issues in quantization documented here:
    # https://github.com/pytorch/pytorch/issues/50710
    # Once that is fixed we can make this the default behavior.
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))


def _find_proxy(*objects_to_search):
    """
    Recursively search a data structure for a Proxy() and return it,
    returning None if not found.
    """
    proxy = None

    def find_proxy(x):
        nonlocal proxy
        if isinstance(x, Proxy):
            proxy = x

    map_aggregate(objects_to_search, find_proxy)
    return proxy


def _create_wrapped_func(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Given a closed-over ``orig_fn`` to invoke, search the args and kwargs for
        a Proxy object. If there is one, emit a ``call_function`` node to preserve the
        call to this leaf function directly. Otherwise, just return the results of
        this function call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return_proxy = proxy.tracer.create_proxy(
                "call_function", orig_fn, args, kwargs
            )
            return_proxy.node.meta["is_wrapped"] = True
            return return_proxy
        return orig_fn(*args, **kwargs)

    return wrapped


def _create_wrapped_method(cls, name):
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Search the args and kwargs for a Proxy object. If there is one,
        emit a ``call_method`` node to preserve the call to this method
        directly. Otherwise, just return the results of this function
        call, as this function is not being traced.
981 """ 982 proxy = _find_proxy(args, kwargs) 983 if proxy is not None: 984 return proxy.tracer.create_proxy("call_method", name, args, kwargs) 985 return orig_fn(*args, **kwargs) 986 987 return wrapped 988 989 990class _PatchedFn(NamedTuple): 991 frame_dict: Any 992 fn_name: str 993 orig_fn: Any 994 new_fn: Any 995 996 def revert(self): 997 raise NotImplementedError 998 999 def patch(self): 1000 raise NotImplementedError 1001 1002 1003class _PatchedFnSetItem(_PatchedFn): 1004 def revert(self): 1005 self.frame_dict[self.fn_name] = self.orig_fn 1006 1007 def patch(self): 1008 self.frame_dict[self.fn_name] = self.new_fn 1009 1010class _PatchedFnDel(_PatchedFn): 1011 def revert(self): 1012 del self.frame_dict[self.fn_name] 1013 1014 def patch(self): 1015 self.frame_dict[self.fn_name] = self.new_fn 1016 1017 1018class _PatchedFnSetAttr(_PatchedFn): 1019 def revert(self): 1020 setattr(self.frame_dict, self.fn_name, self.orig_fn) 1021 1022 def patch(self): 1023 setattr(self.frame_dict, self.fn_name, self.new_fn) 1024 1025class _Patcher: 1026 def __init__(self) -> None: 1027 super().__init__() 1028 self.patches_made: List[_PatchedFn] = [] 1029 self.visited: Set[int] = set() 1030 1031 def patch( 1032 self, 1033 frame_dict: Dict[str, Any], 1034 name: str, 1035 new_fn: Callable, 1036 deduplicate: bool = True, 1037 ): 1038 """ 1039 Replace frame_dict[name] with new_fn until we exit the context manager. 1040 """ 1041 new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] 1042 if name not in frame_dict and hasattr(builtins, name): 1043 self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) 1044 self.patches_made[-1].patch() 1045 elif getattr(frame_dict[name], "__fx_already_patched", False): 1046 return # already patched, no need to do it again 1047 else: 1048 self.patches_made.append( 1049 _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) 1050 ) 1051 self.patches_made[-1].patch() 1052 1053 def patch_method( 1054 self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True 1055 ): 1056 """ 1057 Replace object_or_dict.name with new_fn until we exit the context manager. 1058 """ 1059 new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] 1060 orig_fn = getattr(cls, name) 1061 if getattr(orig_fn, "__fx_already_patched", False): 1062 return # already patched, no need to do it again 1063 self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) 1064 self.patches_made[-1].patch() 1065 1066 def visit_once(self, thing: Any): 1067 """Return True on the first call to with thing, otherwise false""" 1068 idx = id(thing) 1069 if idx in self.visited: 1070 return False 1071 self.visited.add(idx) 1072 return True 1073 1074 def revert_all_patches(self): 1075 """ 1076 Remove all the stored patcheds. It doesn't modify patches_made. 1077 """ 1078 for patch in self.patches_made: 1079 patch.revert() 1080 return self.patches_made 1081 1082 def reapply_all_patches(self): 1083 """ 1084 Patch all the stored patcheds. It doesn't modify patches_made. 
1085 """ 1086 for patch in self.patches_made: 1087 patch.patch() 1088 return self.patches_made 1089 1090 def __enter__(self): 1091 return self 1092 1093 def __exit__(self, exc_type, exc_val, exc_tb): 1094 """ 1095 Undo all the changes made via self.patch() and self.patch_method() 1096 """ 1097 while self.patches_made: 1098 # unpatch in reverse order to handle duplicates correctly 1099 self.patches_made.pop().revert() 1100 self.visited.clear() 1101 1102 1103CURRENT_PATCHER: Optional[_Patcher] = None 1104 1105@contextlib.contextmanager 1106def _new_patcher(): 1107 global CURRENT_PATCHER 1108 prior_patcher = CURRENT_PATCHER 1109 try: 1110 CURRENT_PATCHER = _Patcher() 1111 yield CURRENT_PATCHER 1112 finally: 1113 # Clear all the patches made by when using current patcher. 1114 assert CURRENT_PATCHER is not None 1115 CURRENT_PATCHER.revert_all_patches() 1116 CURRENT_PATCHER = prior_patcher 1117 1118 1119@contextlib.contextmanager 1120def _maybe_revert_all_patches(): 1121 current_patcher = CURRENT_PATCHER 1122 patches_made = None 1123 patches_removed = None 1124 try: 1125 if current_patcher is not None: 1126 patches_removed = current_patcher.revert_all_patches() 1127 yield 1128 finally: 1129 if current_patcher is not None: 1130 patches_made = current_patcher.reapply_all_patches() 1131 assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" 1132 1133def _patch_wrapped_functions(patcher: _Patcher): 1134 """ 1135 Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap 1136 the listed global functions in the `_create_wrapped_func` wrapper. 1137 """ 1138 for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): 1139 if name not in frame_dict and hasattr(builtins, name): 1140 orig_fn = getattr(builtins, name) 1141 else: 1142 orig_fn = frame_dict[name] 1143 patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) 1144 1145 for cls, name in _wrapped_methods_to_patch: 1146 patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) 1147 1148 1149def _autowrap_check( 1150 patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] 1151): 1152 """ 1153 Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. 1154 This method searches a scope for them and patches them if found. 1155 """ 1156 if patcher.visit_once(frame_dict): 1157 for name, value in frame_dict.items(): 1158 if ( 1159 not name.startswith("_") 1160 and callable(value) 1161 and id(value) in function_ids 1162 ): 1163 patcher.patch(frame_dict, name, _create_wrapped_func(value)) 1164 1165 1166@compatibility(is_backward_compatible=True) 1167def wrap(fn_or_name: Union[str, Callable]): 1168 """ 1169 This function can be called at module-level scope to register fn_or_name as a "leaf function". 1170 A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being 1171 traced through:: 1172 1173 # foo/bar/baz.py 1174 def my_custom_function(x, y): 1175 return x * x + y * y 1176 1177 torch.fx.wrap('my_custom_function') 1178 1179 def fn_to_be_traced(x, y): 1180 # When symbolic tracing, the below call to my_custom_function will be inserted into 1181 # the graph rather than tracing it. 
            return my_custom_function(x, y)

    This function can also equivalently be used as a decorator::

        # foo/bar/baz.py
        @torch.fx.wrap
        def my_custom_function(x, y):
            return x * x + y * y

    A wrapped function can be thought of as a "leaf function", analogous to the
    concept of "leaf modules"; that is, it is a function that is left as a call
    in the FX trace rather than traced through.

    Args:

        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
            graph when it's called
    """
    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
        raise RuntimeError(
            "Unsupported type for global function! Must be either a callable or "
            "string name"
        )

    if callable(fn_or_name):
        assert not isinstance(fn_or_name, str)  # to make mypy happy
        fn_name = fn_or_name.__name__
    else:
        assert isinstance(
            fn_or_name, str
        ), "fn_or_name must be a global function or string name"
        fn_name = fn_or_name

    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")

    # TODO: consider implementing a Callable version of this via
    # _autowrap_function_ids / _autowrap_search. The semantics would be slightly
    # different, but it would add support for `from x import wrapped_function`.
    _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
    return fn_or_name


@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ``concrete_args`` allows you to partially specialize your function, whether
    it's to remove control flow or data structures.

    For example::

        def f(a, b):
            if b == True:
                return a
            else:
                return a * 2

    FX typically cannot trace through this due to the presence of control
    flow. However, we can use `concrete_args` to specialize on the value of
    `b` to trace through this::

        f = fx.symbolic_trace(f, concrete_args={'b': False})
        assert f(3, False) == 6

    Note that although you can still pass in different values of `b`, they will be ignored.

    We can also use `concrete_args` to eliminate data-structure handling from
    our function. This will use pytrees to flatten your input. To avoid
    overspecializing, pass in `fx.PH` for values that shouldn't be
    specialized. For example::

        def f(x):
            out = 0
            for v in x.values():
                out += v
            return out

        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
        assert f({'a': 1, 'b': 2, 'c': 4}) == 7

    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
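
    Note that tracing specializes on any non-``PH`` values in ``concrete_args``:
    for primitive types and ``None``, an assertion is recorded into the graph to
    check that later inputs match the specialized value; for other types, a
    warning is emitted instead.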
1279 """ 1280 tracer = Tracer() 1281 graph = tracer.trace(root, concrete_args) 1282 name = ( 1283 root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ 1284 ) 1285 return _make_graph_module(tracer.root, graph, name) 1286 1287 1288@wrap 1289def _assert_is_none(value, msg): 1290 assert value is None, msg 1291