# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
import types
import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops
from torch._C import _NodeBase

# `Graph` is imported for type annotations only.
if TYPE_CHECKING:
    from .graph import Graph

__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]

# Non-container value types that may appear as a Node argument.
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload,
                          torch.SymInt, torch.SymBool, torch.SymFloat]
# The member types of the Union above, as a tuple.
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

# A Node target is either a callable or a string.
Target = Union[Callable[..., Any], str]

# Type of a single Node argument: None, a container of arguments, a slice/range,
# another Node, or one of the base types.
Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    range,
    'Node',
    BaseArgumentTypes
]]

# Valid opcode strings. A dict (all values None) is used so membership tests are
# hash lookups while insertion order is kept.
_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'])

# Callables that are always merged into the side-effectful function set.
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
    torch._C._set_grad_enabled,
    torch.amp._enter_autocast,
    torch.amp._exit_autocast,
}

# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
# or add logic to correctly mark all inplace ops as side effectful.
46*da0073e9SAndroid Build Coastguard Worker_side_effectful_functions: Set[Callable] = { 47*da0073e9SAndroid Build Coastguard Worker torch._assert, 48*da0073e9SAndroid Build Coastguard Worker torch._assert_async, 49*da0073e9SAndroid Build Coastguard Worker _ops.aten._assert_async.msg, 50*da0073e9SAndroid Build Coastguard Worker _ops.aten._assert_scalar.default, 51*da0073e9SAndroid Build Coastguard Worker _ops.aten.sym_constrain_range.default, 52*da0073e9SAndroid Build Coastguard Worker _ops.aten.sym_constrain_range_for_size.default, 53*da0073e9SAndroid Build Coastguard Worker _ops.profiler._record_function_enter, 54*da0073e9SAndroid Build Coastguard Worker _ops.profiler._record_function_enter_new, 55*da0073e9SAndroid Build Coastguard Worker _ops.profiler._record_function_exit, 56*da0073e9SAndroid Build Coastguard Worker _ops.inductor.accumulate_grad_.default, 57*da0073e9SAndroid Build Coastguard Worker} | _side_effectful_need_to_be_preserved_pre_dispatch 58*da0073e9SAndroid Build Coastguard Workerif hasattr(_ops.inductor, "resize_storage_bytes_"): 59*da0073e9SAndroid Build Coastguard Worker _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default) 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False) 63*da0073e9SAndroid Build Coastguard Workerdef has_side_effect(fn: Callable) -> Callable: 64*da0073e9SAndroid Build Coastguard Worker _side_effectful_functions.add(fn) 65*da0073e9SAndroid Build Coastguard Worker return fn 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker# this is fixed on master, WAR for 1.5 69*da0073e9SAndroid Build Coastguard Workerdef _find_module_of_method(orig_method: Callable[..., Any]) -> str: 70*da0073e9SAndroid Build Coastguard Worker name = orig_method.__name__ 71*da0073e9SAndroid Build Coastguard Worker module = 
orig_method.__module__ 72*da0073e9SAndroid Build Coastguard Worker if module is not None: 73*da0073e9SAndroid Build Coastguard Worker return module 74*da0073e9SAndroid Build Coastguard Worker for guess in [torch, torch.nn.functional]: 75*da0073e9SAndroid Build Coastguard Worker if getattr(guess, name, None) is orig_method: 76*da0073e9SAndroid Build Coastguard Worker return guess.__name__ 77*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'cannot find module for {orig_method}') 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker# Borrowed from CPython typing module 80*da0073e9SAndroid Build Coastguard Worker# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 81*da0073e9SAndroid Build Coastguard Workerdef _type_repr(obj: object) -> str: 82*da0073e9SAndroid Build Coastguard Worker """Return the repr() of an object, special-casing types (internal helper). 83*da0073e9SAndroid Build Coastguard Worker If obj is a type, we return a shorter version than the default 84*da0073e9SAndroid Build Coastguard Worker type.__repr__, based on the module and qualified name, which is 85*da0073e9SAndroid Build Coastguard Worker typically enough to uniquely identify a type. For everything 86*da0073e9SAndroid Build Coastguard Worker else, we fall back on repr(obj). 87*da0073e9SAndroid Build Coastguard Worker """ 88*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, type): 89*da0073e9SAndroid Build Coastguard Worker if obj.__module__ == 'builtins': 90*da0073e9SAndroid Build Coastguard Worker return obj.__qualname__ 91*da0073e9SAndroid Build Coastguard Worker return f'{obj.__module__}.{obj.__qualname__}' 92*da0073e9SAndroid Build Coastguard Worker if obj is ...: 93*da0073e9SAndroid Build Coastguard Worker return '...' 
94*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, types.FunctionType): 95*da0073e9SAndroid Build Coastguard Worker return obj.__name__ 96*da0073e9SAndroid Build Coastguard Worker return repr(obj) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Workerdef _get_qualified_name(func: Callable[..., Any]) -> str: 99*da0073e9SAndroid Build Coastguard Worker # things like getattr just appear in builtins 100*da0073e9SAndroid Build Coastguard Worker if getattr(builtins, func.__name__, None) is func: 101*da0073e9SAndroid Build Coastguard Worker return func.__name__ 102*da0073e9SAndroid Build Coastguard Worker # torch.Tensor.{fn} 103*da0073e9SAndroid Build Coastguard Worker if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) 104*da0073e9SAndroid Build Coastguard Worker and func is getattr(torch.Tensor, func.__name__, None)): 105*da0073e9SAndroid Build Coastguard Worker return f"torch.Tensor.{func.__name__}" 106*da0073e9SAndroid Build Coastguard Worker name = func.__name__ 107*da0073e9SAndroid Build Coastguard Worker if name == "<lambda>": 108*da0073e9SAndroid Build Coastguard Worker # For lambdas, try to get their defining name in the module 109*da0073e9SAndroid Build Coastguard Worker try: 110*da0073e9SAndroid Build Coastguard Worker name = inspect.getsource(func).split("=")[0].strip() 111*da0073e9SAndroid Build Coastguard Worker except Exception as e: 112*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Unable to represent lambda") from e 113*da0073e9SAndroid Build Coastguard Worker module = _find_module_of_method(func) 114*da0073e9SAndroid Build Coastguard Worker module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module 115*da0073e9SAndroid Build Coastguard Worker # Fixup segment_reduce mismatch 116*da0073e9SAndroid Build Coastguard Worker if module == "torch" and name == "segment_reduce": 117*da0073e9SAndroid Build Coastguard Worker name = "_" + name 
118*da0073e9SAndroid Build Coastguard Worker return f'{module}.{name}' 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Workerdef _format_arg(arg: object, max_list_len: float = float('inf')) -> str: 121*da0073e9SAndroid Build Coastguard Worker if hasattr(arg, '_custom_fx_repr_fn'): 122*da0073e9SAndroid Build Coastguard Worker return arg._custom_fx_repr_fn() 123*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, list): 124*da0073e9SAndroid Build Coastguard Worker items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) 125*da0073e9SAndroid Build Coastguard Worker maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' 126*da0073e9SAndroid Build Coastguard Worker return f'[{items}{maybe_len}]' 127*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, tuple): 128*da0073e9SAndroid Build Coastguard Worker items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) 129*da0073e9SAndroid Build Coastguard Worker maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' 130*da0073e9SAndroid Build Coastguard Worker maybe_comma = ',' if len(arg) == 1 else '' 131*da0073e9SAndroid Build Coastguard Worker return f'({items}{maybe_comma}{maybe_len})' 132*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, dict): 133*da0073e9SAndroid Build Coastguard Worker items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) 134*da0073e9SAndroid Build Coastguard Worker return f'{{{items_str}}}' 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker if isinstance(arg, Node): 137*da0073e9SAndroid Build Coastguard Worker return '%' + str(arg) 138*da0073e9SAndroid Build Coastguard Worker else: 139*da0073e9SAndroid Build Coastguard Worker return str(arg) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 
142*da0073e9SAndroid Build Coastguard Workerclass Node(_NodeBase): 143*da0073e9SAndroid Build Coastguard Worker """ 144*da0073e9SAndroid Build Coastguard Worker ``Node`` is the data structure that represents individual operations within 145*da0073e9SAndroid Build Coastguard Worker a ``Graph``. For the most part, Nodes represent callsites to various entities, 146*da0073e9SAndroid Build Coastguard Worker such as operators, methods, and Modules (some exceptions include nodes that 147*da0073e9SAndroid Build Coastguard Worker specify function inputs and outputs). Each ``Node`` has a function specified 148*da0073e9SAndroid Build Coastguard Worker by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. 151*da0073e9SAndroid Build Coastguard Worker ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument 152*da0073e9SAndroid Build Coastguard Worker denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to 153*da0073e9SAndroid Build Coastguard Worker the function parameters (e.g. ``x``) in the graph printout. 154*da0073e9SAndroid Build Coastguard Worker - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the 155*da0073e9SAndroid Build Coastguard Worker fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. 156*da0073e9SAndroid Build Coastguard Worker ``args`` and ``kwargs`` are don't-care 157*da0073e9SAndroid Build Coastguard Worker - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign 158*da0073e9SAndroid Build Coastguard Worker to. 
``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, 159*da0073e9SAndroid Build Coastguard Worker following the Python calling convention 160*da0073e9SAndroid Build Coastguard Worker - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is 161*da0073e9SAndroid Build Coastguard Worker as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. 162*da0073e9SAndroid Build Coastguard Worker ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. 163*da0073e9SAndroid Build Coastguard Worker - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method 164*da0073e9SAndroid Build Coastguard Worker to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, 165*da0073e9SAndroid Build Coastguard Worker *including the self argument* 166*da0073e9SAndroid Build Coastguard Worker - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement 167*da0073e9SAndroid Build Coastguard Worker in the Graph printout. 168*da0073e9SAndroid Build Coastguard Worker """ 169*da0073e9SAndroid Build Coastguard Worker _args: Tuple['Argument', ...] 
170*da0073e9SAndroid Build Coastguard Worker _kwargs: Dict[str, 'Argument'] 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 173*da0073e9SAndroid Build Coastguard Worker def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', 174*da0073e9SAndroid Build Coastguard Worker args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], 175*da0073e9SAndroid Build Coastguard Worker return_type : Optional[Any] = None) -> None: 176*da0073e9SAndroid Build Coastguard Worker """ 177*da0073e9SAndroid Build Coastguard Worker Instantiate an instance of ``Node``. Note: most often, you want to use the 178*da0073e9SAndroid Build Coastguard Worker Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather 179*da0073e9SAndroid Build Coastguard Worker than instantiating a ``Node`` directly. 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker Args: 182*da0073e9SAndroid Build Coastguard Worker graph (Graph): The ``Graph`` to which this ``Node`` should belong. 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker name (str): The name to which the output of this ``Node`` should be assigned 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker op (str): The opcode for this ``Node``. Can be one of 'placeholder', 187*da0073e9SAndroid Build Coastguard Worker 'call_method', 'call_module', 'call_function', 'get_attr', 188*da0073e9SAndroid Build Coastguard Worker 'output' 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker target ('Target'): The target this op should call. See the broader 191*da0073e9SAndroid Build Coastguard Worker ``Node`` docstring for more details. 
192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker args (Tuple['Argument']): The args to be passed to ``target`` 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker return_type (Optional[Any]): The python type expression representing the 198*da0073e9SAndroid Build Coastguard Worker type of the output of this node. This field can be used for 199*da0073e9SAndroid Build Coastguard Worker annotation of values in the generated code or for other types 200*da0073e9SAndroid Build Coastguard Worker of analyses. 201*da0073e9SAndroid Build Coastguard Worker """ 202*da0073e9SAndroid Build Coastguard Worker super().__init__() 203*da0073e9SAndroid Build Coastguard Worker self.graph = graph 204*da0073e9SAndroid Build Coastguard Worker self.name = name # unique name of value being created 205*da0073e9SAndroid Build Coastguard Worker assert op in _legal_ops 206*da0073e9SAndroid Build Coastguard Worker self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr 207*da0073e9SAndroid Build Coastguard Worker if op == 'call_function': 208*da0073e9SAndroid Build Coastguard Worker if not callable(target): 209*da0073e9SAndroid Build Coastguard Worker raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 210*da0073e9SAndroid Build Coastguard Worker 'but a Callable is expected') 211*da0073e9SAndroid Build Coastguard Worker else: 212*da0073e9SAndroid Build Coastguard Worker if not isinstance(target, str): 213*da0073e9SAndroid Build Coastguard Worker raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 214*da0073e9SAndroid Build Coastguard Worker 'but a str is expected') 215*da0073e9SAndroid Build 
Coastguard Worker self.target = target # for method/module/function, the name of the method/module/function/attr 216*da0073e9SAndroid Build Coastguard Worker # being invoked, e.g add, layer1, or torch.add 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker # All `Node`-valued inputs. Key is the Node, value is don't-care. 219*da0073e9SAndroid Build Coastguard Worker # The public API for this is `all_input_nodes`, this private attribute 220*da0073e9SAndroid Build Coastguard Worker # should not be accessed directly. 221*da0073e9SAndroid Build Coastguard Worker self._input_nodes : Dict[Node, None] = {} 222*da0073e9SAndroid Build Coastguard Worker self.__update_args_kwargs(args, kwargs) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker # All of the nodes that use the value produced by this Node 225*da0073e9SAndroid Build Coastguard Worker # Note one user may correspond to several uses, e.g. the node fo ``x + x`` 226*da0073e9SAndroid Build Coastguard Worker # would appear once here, but represents two uses. 227*da0073e9SAndroid Build Coastguard Worker # 228*da0073e9SAndroid Build Coastguard Worker # Is a dict to act as an "ordered set". Keys are significant, value dont-care 229*da0073e9SAndroid Build Coastguard Worker self.users : Dict[Node, None] = {} 230*da0073e9SAndroid Build Coastguard Worker # Type expression representing the output value of this node. 231*da0073e9SAndroid Build Coastguard Worker # This should contain the same class of Type objects that would appear 232*da0073e9SAndroid Build Coastguard Worker # as type annotations for function inputs/outputs. 233*da0073e9SAndroid Build Coastguard Worker # 234*da0073e9SAndroid Build Coastguard Worker # For placeholder nodes, this value will be used to type-annotate the 235*da0073e9SAndroid Build Coastguard Worker # generated function parameters. 
236*da0073e9SAndroid Build Coastguard Worker # For the return node, this value will be used to type-annotate the 237*da0073e9SAndroid Build Coastguard Worker # generated function return type. (Note this is a special case. ``return`` 238*da0073e9SAndroid Build Coastguard Worker # does not produce a value, it's more of a notation. Thus, this value 239*da0073e9SAndroid Build Coastguard Worker # describes the type of args[0] in the ``return`` node. 240*da0073e9SAndroid Build Coastguard Worker self.type : Optional[Any] = return_type 241*da0073e9SAndroid Build Coastguard Worker self._sort_key: Any = () 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker # If set, use this fn to print this node 244*da0073e9SAndroid Build Coastguard Worker self._repr_fn : Optional[Callable[[Node], str]] = None 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker # Dictionary to store metadata passes need to do their 247*da0073e9SAndroid Build Coastguard Worker # transformations. 
This metadata is preserved across node copies 248*da0073e9SAndroid Build Coastguard Worker self.meta : Dict[str, Any] = {} 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker def __getstate__(self) -> Dict[str, Any]: 251*da0073e9SAndroid Build Coastguard Worker state = self.__dict__.copy() 252*da0073e9SAndroid Build Coastguard Worker state["_erased"] = self._erased 253*da0073e9SAndroid Build Coastguard Worker state["_prev"] = self._prev 254*da0073e9SAndroid Build Coastguard Worker state["_next"] = self._next 255*da0073e9SAndroid Build Coastguard Worker return state 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state: Dict[str, Any]) -> None: 258*da0073e9SAndroid Build Coastguard Worker _erased = state.pop("_erased") 259*da0073e9SAndroid Build Coastguard Worker _prev = state.pop("_prev") 260*da0073e9SAndroid Build Coastguard Worker _next = state.pop("_next") 261*da0073e9SAndroid Build Coastguard Worker self.__dict__.update(state) 262*da0073e9SAndroid Build Coastguard Worker self._erased = _erased 263*da0073e9SAndroid Build Coastguard Worker self._prev = _prev 264*da0073e9SAndroid Build Coastguard Worker self._next = _next 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker @property 267*da0073e9SAndroid Build Coastguard Worker def next(self) -> 'Node': 268*da0073e9SAndroid Build Coastguard Worker """ 269*da0073e9SAndroid Build Coastguard Worker Returns the next ``Node`` in the linked list of Nodes. 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker Returns: 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker The next ``Node`` in the linked list of Nodes. 
274*da0073e9SAndroid Build Coastguard Worker """ 275*da0073e9SAndroid Build Coastguard Worker return self._next 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker @property 278*da0073e9SAndroid Build Coastguard Worker def prev(self) -> 'Node': 279*da0073e9SAndroid Build Coastguard Worker """ 280*da0073e9SAndroid Build Coastguard Worker Returns the previous ``Node`` in the linked list of Nodes. 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker Returns: 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker The previous ``Node`` in the linked list of Nodes. 285*da0073e9SAndroid Build Coastguard Worker """ 286*da0073e9SAndroid Build Coastguard Worker return self._prev 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 289*da0073e9SAndroid Build Coastguard Worker def prepend(self, x: 'Node') -> None: 290*da0073e9SAndroid Build Coastguard Worker """ 291*da0073e9SAndroid Build Coastguard Worker Insert x before this node in the list of nodes in the graph. Example:: 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker Before: p -> self 294*da0073e9SAndroid Build Coastguard Worker bx -> x -> ax 295*da0073e9SAndroid Build Coastguard Worker After: p -> x -> self 296*da0073e9SAndroid Build Coastguard Worker bx -> ax 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker Args: 299*da0073e9SAndroid Build Coastguard Worker x (Node): The node to put before this node. Must be a member of the same graph. 
300*da0073e9SAndroid Build Coastguard Worker """ 301*da0073e9SAndroid Build Coastguard Worker assert self.graph == x.graph, "Attempting to move a Node into a different Graph" 302*da0073e9SAndroid Build Coastguard Worker if self == x: 303*da0073e9SAndroid Build Coastguard Worker warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") 304*da0073e9SAndroid Build Coastguard Worker return 305*da0073e9SAndroid Build Coastguard Worker x._remove_from_list() 306*da0073e9SAndroid Build Coastguard Worker p = self._prev 307*da0073e9SAndroid Build Coastguard Worker p._next, x._prev = x, p 308*da0073e9SAndroid Build Coastguard Worker x._next, self._prev = self, x 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker # compute x._sort_key 311*da0073e9SAndroid Build Coastguard Worker psk = x._prev._sort_key 312*da0073e9SAndroid Build Coastguard Worker nsk = x._next._sort_key 313*da0073e9SAndroid Build Coastguard Worker if len(psk) > len(nsk): 314*da0073e9SAndroid Build Coastguard Worker idx: int 315*da0073e9SAndroid Build Coastguard Worker *prefix, idx = psk[:len(nsk) + 1] 316*da0073e9SAndroid Build Coastguard Worker x._sort_key = (*prefix, idx + 1) 317*da0073e9SAndroid Build Coastguard Worker elif len(psk) < len(nsk): 318*da0073e9SAndroid Build Coastguard Worker *prefix, idx = nsk[:len(psk) + 1] 319*da0073e9SAndroid Build Coastguard Worker x._sort_key = (*prefix, idx - 1) 320*da0073e9SAndroid Build Coastguard Worker else: # same length, increase length by 1 321*da0073e9SAndroid Build Coastguard Worker x._sort_key = (*psk, 0) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker def __gt__(self, other: 'Node') -> bool: 324*da0073e9SAndroid Build Coastguard Worker return self._sort_key > other._sort_key 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker def __lt__(self, other: 'Node') -> bool: 327*da0073e9SAndroid Build Coastguard 
Worker return self._sort_key < other._sort_key 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker def __ge__(self, other: 'Node') -> bool: 330*da0073e9SAndroid Build Coastguard Worker return self > other or self == other 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker def __le__(self, other: 'Node') -> bool: 333*da0073e9SAndroid Build Coastguard Worker return self < other or self == other 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 336*da0073e9SAndroid Build Coastguard Worker def append(self, x: 'Node') -> None: 337*da0073e9SAndroid Build Coastguard Worker """ 338*da0073e9SAndroid Build Coastguard Worker Insert ``x`` after this node in the list of nodes in the graph. 339*da0073e9SAndroid Build Coastguard Worker Equivalent to ``self.next.prepend(x)`` 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker Args: 342*da0073e9SAndroid Build Coastguard Worker x (Node): The node to put after this node. Must be a member of the same graph. 343*da0073e9SAndroid Build Coastguard Worker """ 344*da0073e9SAndroid Build Coastguard Worker self._next.prepend(x) 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker def _remove_from_list(self) -> None: 347*da0073e9SAndroid Build Coastguard Worker p, n = self._prev, self._next 348*da0073e9SAndroid Build Coastguard Worker p._next, n._prev = n, p 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker @property 351*da0073e9SAndroid Build Coastguard Worker def args(self) -> Tuple[Argument, ...]: 352*da0073e9SAndroid Build Coastguard Worker """ 353*da0073e9SAndroid Build Coastguard Worker The tuple of arguments to this ``Node``. The interpretation of arguments 354*da0073e9SAndroid Build Coastguard Worker depends on the node's opcode. 
See the :class:`Node` docstring for more 355*da0073e9SAndroid Build Coastguard Worker information. 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker Assignment to this property is allowed. All accounting of uses and users 358*da0073e9SAndroid Build Coastguard Worker is updated automatically on assignment. 359*da0073e9SAndroid Build Coastguard Worker """ 360*da0073e9SAndroid Build Coastguard Worker return self._args 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker @args.setter 363*da0073e9SAndroid Build Coastguard Worker def args(self, a : Tuple[Argument, ...]) -> None: 364*da0073e9SAndroid Build Coastguard Worker """ 365*da0073e9SAndroid Build Coastguard Worker Set the tuple of arguments to this Node. The interpretation of arguments 366*da0073e9SAndroid Build Coastguard Worker depends on the node's opcode. See the ``fx.Graph`` docstring for more 367*da0073e9SAndroid Build Coastguard Worker information. 368*da0073e9SAndroid Build Coastguard Worker """ 369*da0073e9SAndroid Build Coastguard Worker # DO NOT CALL `__update_args_kwargs` directly. The correct way to 370*da0073e9SAndroid Build Coastguard Worker # set `args` is via direct assignment, i.e. `node.args = new_args` 371*da0073e9SAndroid Build Coastguard Worker self.__update_args_kwargs(a, self._kwargs) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker @property 374*da0073e9SAndroid Build Coastguard Worker def kwargs(self) -> Dict[str, Argument]: 375*da0073e9SAndroid Build Coastguard Worker """ 376*da0073e9SAndroid Build Coastguard Worker The dict of keyword arguments to this ``Node``. The interpretation of arguments 377*da0073e9SAndroid Build Coastguard Worker depends on the node's opcode. See the :class:`Node` docstring for more 378*da0073e9SAndroid Build Coastguard Worker information. 
379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker Assignment to this property is allowed. All accounting of uses and users 381*da0073e9SAndroid Build Coastguard Worker is updated automatically on assignment. 382*da0073e9SAndroid Build Coastguard Worker """ 383*da0073e9SAndroid Build Coastguard Worker return self._kwargs 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker @kwargs.setter 386*da0073e9SAndroid Build Coastguard Worker def kwargs(self, k : Dict[str, Argument]) -> None: 387*da0073e9SAndroid Build Coastguard Worker """ 388*da0073e9SAndroid Build Coastguard Worker Set the dict of kwargs to this Node. The interpretation of arguments 389*da0073e9SAndroid Build Coastguard Worker depends on the node's opcode. See the ``fx.Graph`` docstring for more 390*da0073e9SAndroid Build Coastguard Worker information. 391*da0073e9SAndroid Build Coastguard Worker """ 392*da0073e9SAndroid Build Coastguard Worker # DO NOT CALL `__update_args_kwargs` directly. The correct way to 393*da0073e9SAndroid Build Coastguard Worker # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` 394*da0073e9SAndroid Build Coastguard Worker self.__update_args_kwargs(self._args, k) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker @property 397*da0073e9SAndroid Build Coastguard Worker def all_input_nodes(self) -> List['Node']: 398*da0073e9SAndroid Build Coastguard Worker """ 399*da0073e9SAndroid Build Coastguard Worker Return all Nodes that are inputs to this Node. This is equivalent to 400*da0073e9SAndroid Build Coastguard Worker iterating over ``args`` and ``kwargs`` and only collecting the values that 401*da0073e9SAndroid Build Coastguard Worker are Nodes. 
402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker Returns: 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this 406*da0073e9SAndroid Build Coastguard Worker ``Node``, in that order. 407*da0073e9SAndroid Build Coastguard Worker """ 408*da0073e9SAndroid Build Coastguard Worker return list(self._input_nodes.keys()) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 411*da0073e9SAndroid Build Coastguard Worker def update_arg(self, idx : int, arg : Argument) -> None: 412*da0073e9SAndroid Build Coastguard Worker """ 413*da0073e9SAndroid Build Coastguard Worker Update an existing positional argument to contain the new value 414*da0073e9SAndroid Build Coastguard Worker ``arg``. After calling, ``self.args[idx] == arg``. 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker Args: 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker idx (int): The index into ``self.args`` of the element to update 419*da0073e9SAndroid Build Coastguard Worker arg (Argument): The new argument value to write into ``args`` 420*da0073e9SAndroid Build Coastguard Worker """ 421*da0073e9SAndroid Build Coastguard Worker args = list(self.args) 422*da0073e9SAndroid Build Coastguard Worker args[idx] = arg 423*da0073e9SAndroid Build Coastguard Worker self.args = tuple(args) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 426*da0073e9SAndroid Build Coastguard Worker def insert_arg(self, idx : int, arg : Argument) -> None: 427*da0073e9SAndroid Build Coastguard Worker """ 428*da0073e9SAndroid Build Coastguard Worker Insert an positional argument to the argument list with given index. 
429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker Args: 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker idx (int): The index of the element in ``self.args`` to be inserted before. 433*da0073e9SAndroid Build Coastguard Worker arg (Argument): The new argument value to insert into ``args`` 434*da0073e9SAndroid Build Coastguard Worker """ 435*da0073e9SAndroid Build Coastguard Worker assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" 436*da0073e9SAndroid Build Coastguard Worker args_left = self.args[:idx] 437*da0073e9SAndroid Build Coastguard Worker args_right = self.args[idx:] 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker self._args = args_left + (arg,) + args_right 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker _new_input_nodes: Dict[Node, None] = {} 442*da0073e9SAndroid Build Coastguard Worker map_arg(arg, _new_input_nodes.setdefault) 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker for new_use in _new_input_nodes.keys(): 445*da0073e9SAndroid Build Coastguard Worker if new_use not in self._input_nodes: 446*da0073e9SAndroid Build Coastguard Worker self._input_nodes.setdefault(new_use) 447*da0073e9SAndroid Build Coastguard Worker new_use.users.setdefault(self) 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 450*da0073e9SAndroid Build Coastguard Worker def update_kwarg(self, key : str, arg : Argument) -> None: 451*da0073e9SAndroid Build Coastguard Worker """ 452*da0073e9SAndroid Build Coastguard Worker Update an existing keyword argument to contain the new value 453*da0073e9SAndroid Build Coastguard Worker ``arg``. After calling, ``self.kwargs[key] == arg``. 
454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker Args: 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker key (str): The key in ``self.kwargs`` of the element to update 458*da0073e9SAndroid Build Coastguard Worker arg (Argument): The new argument value to write into ``kwargs`` 459*da0073e9SAndroid Build Coastguard Worker """ 460*da0073e9SAndroid Build Coastguard Worker self.kwargs = {**self.kwargs, key: arg} 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker @property 463*da0073e9SAndroid Build Coastguard Worker def stack_trace(self) -> Optional[str]: 464*da0073e9SAndroid Build Coastguard Worker """ 465*da0073e9SAndroid Build Coastguard Worker Return the Python stack trace that was recorded during tracing, if any. 466*da0073e9SAndroid Build Coastguard Worker When traced with fx.Tracer, this property is usually populated by 467*da0073e9SAndroid Build Coastguard Worker `Tracer.create_proxy`. To record stack traces during tracing for debug purposes, 468*da0073e9SAndroid Build Coastguard Worker set `record_stack_traces = True` on the `Tracer` instance. 469*da0073e9SAndroid Build Coastguard Worker When traced with dynamo, this property will be populated by default by 470*da0073e9SAndroid Build Coastguard Worker `OutputGraph.create_proxy`. 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker stack_trace would have the innermost frame at the end of the string. 
473*da0073e9SAndroid Build Coastguard Worker """ 474*da0073e9SAndroid Build Coastguard Worker return self.meta.get("stack_trace", None) 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker @stack_trace.setter 477*da0073e9SAndroid Build Coastguard Worker def stack_trace(self, trace : Optional[str]) -> None: 478*da0073e9SAndroid Build Coastguard Worker self.meta["stack_trace"] = trace 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None: 481*da0073e9SAndroid Build Coastguard Worker """ 482*da0073e9SAndroid Build Coastguard Worker This API is internal. Do *not* call it directly. 483*da0073e9SAndroid Build Coastguard Worker """ 484*da0073e9SAndroid Build Coastguard Worker def update_users_and_input_nodes(n: Any) -> Any: 485*da0073e9SAndroid Build Coastguard Worker if isinstance(n, Node): 486*da0073e9SAndroid Build Coastguard Worker self._input_nodes.setdefault(n) 487*da0073e9SAndroid Build Coastguard Worker n.users.setdefault(self) 488*da0073e9SAndroid Build Coastguard Worker return n 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker # Clear prior users and input_nodes 491*da0073e9SAndroid Build Coastguard Worker for old_use in self._input_nodes.keys(): 492*da0073e9SAndroid Build Coastguard Worker old_use.users.pop(self) 493*da0073e9SAndroid Build Coastguard Worker self._input_nodes = {} 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker # We do three things in a single pass of the args 496*da0073e9SAndroid Build Coastguard Worker # - Normalize list->immutable_list, dict->immutable_dict, etc 497*da0073e9SAndroid Build Coastguard Worker # - Populate self._input_nodes 498*da0073e9SAndroid Build Coastguard Worker # - Populate arg.users[self] for each arg 499*da0073e9SAndroid Build Coastguard Worker self._args = 
map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] 500*da0073e9SAndroid Build Coastguard Worker self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker def __repr__(self) -> str: 503*da0073e9SAndroid Build Coastguard Worker if self._repr_fn: 504*da0073e9SAndroid Build Coastguard Worker return self._repr_fn(self) 505*da0073e9SAndroid Build Coastguard Worker return self.name 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker def _pretty_print_target(self, target: object) -> str: 508*da0073e9SAndroid Build Coastguard Worker """ 509*da0073e9SAndroid Build Coastguard Worker Make target printouts more user-friendly. 510*da0073e9SAndroid Build Coastguard Worker 1) builtins will be printed as `builtins.xyz` 511*da0073e9SAndroid Build Coastguard Worker 2) operators will be printed as `operator.xyz` 512*da0073e9SAndroid Build Coastguard Worker 3) other callables will be printed with qualified name, e.g. torch.add 513*da0073e9SAndroid Build Coastguard Worker """ 514*da0073e9SAndroid Build Coastguard Worker if isinstance(target, str): 515*da0073e9SAndroid Build Coastguard Worker return target 516*da0073e9SAndroid Build Coastguard Worker if hasattr(target, '__module__'): 517*da0073e9SAndroid Build Coastguard Worker name = getattr(target, '__name__', None) 518*da0073e9SAndroid Build Coastguard Worker if name is None: 519*da0073e9SAndroid Build Coastguard Worker # Just to be defensive, if we don't have `__name__`, get the 520*da0073e9SAndroid Build Coastguard Worker # qualname. Not sure if this happens for any members of `operator` 521*da0073e9SAndroid Build Coastguard Worker # or `builtins`. This fallback path is not as good, since e.g. 522*da0073e9SAndroid Build Coastguard Worker # things in `operator` have `_operator` as their __module__. 
523*da0073e9SAndroid Build Coastguard Worker # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` 524*da0073e9SAndroid Build Coastguard Worker return _get_qualified_name(target) # type: ignore[arg-type] 525*da0073e9SAndroid Build Coastguard Worker if target.__module__ == 'builtins': 526*da0073e9SAndroid Build Coastguard Worker return f'builtins.{name}' 527*da0073e9SAndroid Build Coastguard Worker elif target.__module__ == '_operator': 528*da0073e9SAndroid Build Coastguard Worker return f'operator.{name}' 529*da0073e9SAndroid Build Coastguard Worker return _get_qualified_name(target) # type: ignore[arg-type] 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 532*da0073e9SAndroid Build Coastguard Worker def format_node(self, 533*da0073e9SAndroid Build Coastguard Worker placeholder_names: Optional[List[str]] = None, 534*da0073e9SAndroid Build Coastguard Worker maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: 535*da0073e9SAndroid Build Coastguard Worker """ 536*da0073e9SAndroid Build Coastguard Worker Return a descriptive string representation of ``self``. 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker This method can be used with no arguments as a debugging 539*da0073e9SAndroid Build Coastguard Worker utility. 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker This function is also used internally in the ``__str__`` method 542*da0073e9SAndroid Build Coastguard Worker of ``Graph``. Together, the strings in ``placeholder_names`` 543*da0073e9SAndroid Build Coastguard Worker and ``maybe_return_typename`` make up the signature of the 544*da0073e9SAndroid Build Coastguard Worker autogenerated ``forward`` function in this Graph's surrounding 545*da0073e9SAndroid Build Coastguard Worker GraphModule. 
``placeholder_names`` and ``maybe_return_typename`` 546*da0073e9SAndroid Build Coastguard Worker should not be used otherwise. 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker Args: 549*da0073e9SAndroid Build Coastguard Worker placeholder_names: A list that will store formatted strings 550*da0073e9SAndroid Build Coastguard Worker representing the placeholders in the generated 551*da0073e9SAndroid Build Coastguard Worker ``forward`` function. Internal use only. 552*da0073e9SAndroid Build Coastguard Worker maybe_return_typename: A single-element list that will store 553*da0073e9SAndroid Build Coastguard Worker a formatted string representing the output of the 554*da0073e9SAndroid Build Coastguard Worker generated ``forward`` function. Internal use only. 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker Returns: 557*da0073e9SAndroid Build Coastguard Worker str: If 1) we're using ``format_node`` as an internal helper 558*da0073e9SAndroid Build Coastguard Worker in the ``__str__`` method of ``Graph``, and 2) ``self`` 559*da0073e9SAndroid Build Coastguard Worker is a placeholder Node, return ``None``. Otherwise, 560*da0073e9SAndroid Build Coastguard Worker return a descriptive string representation of the 561*da0073e9SAndroid Build Coastguard Worker current Node. 
562*da0073e9SAndroid Build Coastguard Worker """ 563*da0073e9SAndroid Build Coastguard Worker if self.op == 'placeholder': 564*da0073e9SAndroid Build Coastguard Worker assert isinstance(self.target, str) 565*da0073e9SAndroid Build Coastguard Worker arg_str = self.target 566*da0073e9SAndroid Build Coastguard Worker arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' 567*da0073e9SAndroid Build Coastguard Worker if placeholder_names: 568*da0073e9SAndroid Build Coastguard Worker placeholder_names.append(arg_str) 569*da0073e9SAndroid Build Coastguard Worker return None 570*da0073e9SAndroid Build Coastguard Worker maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' 571*da0073e9SAndroid Build Coastguard Worker default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' 572*da0073e9SAndroid Build Coastguard Worker return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' 573*da0073e9SAndroid Build Coastguard Worker elif self.op == 'get_attr': 574*da0073e9SAndroid Build Coastguard Worker maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' 575*da0073e9SAndroid Build Coastguard Worker return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ 576*da0073e9SAndroid Build Coastguard Worker f'{self.op}[target={self._pretty_print_target(self.target)}]' 577*da0073e9SAndroid Build Coastguard Worker elif self.op == 'output': 578*da0073e9SAndroid Build Coastguard Worker if self.type and maybe_return_typename: 579*da0073e9SAndroid Build Coastguard Worker maybe_return_typename[0] = f' -> {_type_repr(self.type)}' 580*da0073e9SAndroid Build Coastguard Worker return f'return {self.args[0]}' 581*da0073e9SAndroid Build Coastguard Worker else: 582*da0073e9SAndroid Build Coastguard Worker maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' 583*da0073e9SAndroid Build Coastguard Worker return f'%{self.name} : 
{maybe_typename}[num_users={len(self.users)}] = ' \ 584*da0073e9SAndroid Build Coastguard Worker f'{self.op}[target={self._pretty_print_target(self.target)}](' \ 585*da0073e9SAndroid Build Coastguard Worker f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 588*da0073e9SAndroid Build Coastguard Worker def replace_all_uses_with(self, 589*da0073e9SAndroid Build Coastguard Worker replace_with: 'Node', 590*da0073e9SAndroid Build Coastguard Worker delete_user_cb: Callable[['Node'], bool] = lambda user: True, 591*da0073e9SAndroid Build Coastguard Worker *, 592*da0073e9SAndroid Build Coastguard Worker propagate_meta: bool = False 593*da0073e9SAndroid Build Coastguard Worker ) -> List['Node']: 594*da0073e9SAndroid Build Coastguard Worker """ 595*da0073e9SAndroid Build Coastguard Worker Replace all uses of ``self`` in the Graph with the Node ``replace_with``. 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker Args: 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker replace_with (Node): The node to replace all uses of ``self`` with. 600*da0073e9SAndroid Build Coastguard Worker delete_user_cb (Callable): Callback that is called to determine 601*da0073e9SAndroid Build Coastguard Worker whether a given user of the self node should be removed. 602*da0073e9SAndroid Build Coastguard Worker propagate_meta (bool): Whether or not to copy all properties 603*da0073e9SAndroid Build Coastguard Worker on the .meta field of the original node onto the replacement node. 604*da0073e9SAndroid Build Coastguard Worker For safety, this is only valid to do if the replacement node 605*da0073e9SAndroid Build Coastguard Worker doesn't already have an existing .meta field. 
606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker Returns: 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker The list of Nodes on which this change was made. 610*da0073e9SAndroid Build Coastguard Worker """ 611*da0073e9SAndroid Build Coastguard Worker if propagate_meta: 612*da0073e9SAndroid Build Coastguard Worker assert len(replace_with.meta) == 0, \ 613*da0073e9SAndroid Build Coastguard Worker 'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \ 614*da0073e9SAndroid Build Coastguard Worker 'but replace_with already has .meta keys' 615*da0073e9SAndroid Build Coastguard Worker for k, v in self.meta.items(): 616*da0073e9SAndroid Build Coastguard Worker replace_with.meta[k] = v 617*da0073e9SAndroid Build Coastguard Worker to_process = list(self.users) 618*da0073e9SAndroid Build Coastguard Worker skipped = [] 619*da0073e9SAndroid Build Coastguard Worker m = self.graph.owning_module 620*da0073e9SAndroid Build Coastguard Worker for use_node in to_process: 621*da0073e9SAndroid Build Coastguard Worker if not delete_user_cb(use_node): 622*da0073e9SAndroid Build Coastguard Worker skipped.append(use_node) 623*da0073e9SAndroid Build Coastguard Worker continue 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker def maybe_replace_node(n : Node) -> Node: 626*da0073e9SAndroid Build Coastguard Worker if n == self: 627*da0073e9SAndroid Build Coastguard Worker return replace_with 628*da0073e9SAndroid Build Coastguard Worker else: 629*da0073e9SAndroid Build Coastguard Worker return n 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker if getattr(m, "_replace_hook", None): 632*da0073e9SAndroid Build Coastguard Worker m._replace_hook(old=self, new=replace_with.name, user=use_node) 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker new_args = map_arg(use_node.args, 
maybe_replace_node) 635*da0073e9SAndroid Build Coastguard Worker new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) 636*da0073e9SAndroid Build Coastguard Worker assert isinstance(new_args, tuple) 637*da0073e9SAndroid Build Coastguard Worker assert isinstance(new_kwargs, dict) 638*da0073e9SAndroid Build Coastguard Worker use_node.__update_args_kwargs(new_args, new_kwargs) 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker assert len(self.users) - len(skipped) == 0 641*da0073e9SAndroid Build Coastguard Worker return [n for n in to_process if n not in skipped] 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 644*da0073e9SAndroid Build Coastguard Worker def is_impure(self) -> bool: 645*da0073e9SAndroid Build Coastguard Worker """ 646*da0073e9SAndroid Build Coastguard Worker Returns whether this op is impure, i.e. if its op is a placeholder or 647*da0073e9SAndroid Build Coastguard Worker output, or if a call_function or call_module which is impure. 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker Returns: 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker bool: If the op is impure or not. 652*da0073e9SAndroid Build Coastguard Worker """ 653*da0073e9SAndroid Build Coastguard Worker if self.op in {"placeholder", "output"}: 654*da0073e9SAndroid Build Coastguard Worker return True 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker # Check if an impure function based on schema. 
657*da0073e9SAndroid Build Coastguard Worker if self.op == "call_function": 658*da0073e9SAndroid Build Coastguard Worker schema = getattr(self.target, "_schema", None) 659*da0073e9SAndroid Build Coastguard Worker schema_mutable = schema is not None and schema.is_mutable 660*da0073e9SAndroid Build Coastguard Worker return schema_mutable or self.target in _side_effectful_functions 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker # Check if an impure module. 663*da0073e9SAndroid Build Coastguard Worker if self.op == "call_module": 664*da0073e9SAndroid Build Coastguard Worker assert ( 665*da0073e9SAndroid Build Coastguard Worker self.graph.owning_module is not None 666*da0073e9SAndroid Build Coastguard Worker ), "self.graph.owning_module not set for purity check" 667*da0073e9SAndroid Build Coastguard Worker target_mod = self.graph.owning_module.get_submodule(self.target) 668*da0073e9SAndroid Build Coastguard Worker assert ( 669*da0073e9SAndroid Build Coastguard Worker target_mod is not None 670*da0073e9SAndroid Build Coastguard Worker ), f"Did not find expected submodule target {self.target}" 671*da0073e9SAndroid Build Coastguard Worker return getattr(target_mod, "_is_impure", False) 672*da0073e9SAndroid Build Coastguard Worker 673*da0073e9SAndroid Build Coastguard Worker return False 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=False) 676*da0073e9SAndroid Build Coastguard Worker def normalized_arguments( 677*da0073e9SAndroid Build Coastguard Worker self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, 678*da0073e9SAndroid Build Coastguard Worker kwarg_types : Optional[Dict[str, Any]] = None, 679*da0073e9SAndroid Build Coastguard Worker normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: 680*da0073e9SAndroid Build Coastguard Worker """ 681*da0073e9SAndroid Build Coastguard Worker Returns normalized arguments 
to Python targets. This means that 682*da0073e9SAndroid Build Coastguard Worker `args/kwargs` will be matched up to the module/functional's 683*da0073e9SAndroid Build Coastguard Worker signature and return exclusively kwargs in positional order 684*da0073e9SAndroid Build Coastguard Worker if `normalize_to_only_use_kwargs` is true. 685*da0073e9SAndroid Build Coastguard Worker Also populates default values. Does not support positional-only 686*da0073e9SAndroid Build Coastguard Worker parameters or varargs parameters. 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker Supports module calls. 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker May require `arg_types` and `kwarg_types` in order to disambiguate overloads. 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker Args: 693*da0073e9SAndroid Build Coastguard Worker root (torch.nn.Module): Module upon which to resolve module targets. 694*da0073e9SAndroid Build Coastguard Worker arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args 695*da0073e9SAndroid Build Coastguard Worker kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs 696*da0073e9SAndroid Build Coastguard Worker normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. 697*da0073e9SAndroid Build Coastguard Worker 698*da0073e9SAndroid Build Coastguard Worker Returns: 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker Returns NamedTuple ArgsKwargsPair, or `None` if not successful. 
701*da0073e9SAndroid Build Coastguard Worker """ 702*da0073e9SAndroid Build Coastguard Worker if self.op == 'call_function': 703*da0073e9SAndroid Build Coastguard Worker assert callable(self.target) 704*da0073e9SAndroid Build Coastguard Worker return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] 705*da0073e9SAndroid Build Coastguard Worker elif self.op == 'call_module': 706*da0073e9SAndroid Build Coastguard Worker assert isinstance(self.target, str) 707*da0073e9SAndroid Build Coastguard Worker return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] 708*da0073e9SAndroid Build Coastguard Worker 709*da0073e9SAndroid Build Coastguard Worker return None 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker @compatibility(is_backward_compatible=True) 712*da0073e9SAndroid Build Coastguard Worker def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: 713*da0073e9SAndroid Build Coastguard Worker """ 714*da0073e9SAndroid Build Coastguard Worker Loop through input nodes of ``self``, and replace all instances of 715*da0073e9SAndroid Build Coastguard Worker ``old_input`` with ``new_input``. 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker Args: 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker old_input (Node): The old input node to be replaced. 720*da0073e9SAndroid Build Coastguard Worker new_input (Node): The new input node to replace ``old_input``. 
721*da0073e9SAndroid Build Coastguard Worker """ 722*da0073e9SAndroid Build Coastguard Worker def maybe_replace_node(n : Node) -> Node: 723*da0073e9SAndroid Build Coastguard Worker return new_input if n == old_input else n 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker m = self.graph.owning_module 726*da0073e9SAndroid Build Coastguard Worker if getattr(m, "_replace_hook", None): 727*da0073e9SAndroid Build Coastguard Worker m._replace_hook(old=old_input, new=new_input.name, user=self) 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker new_args = map_arg(self.args, maybe_replace_node) 730*da0073e9SAndroid Build Coastguard Worker new_kwargs = map_arg(self.kwargs, maybe_replace_node) 731*da0073e9SAndroid Build Coastguard Worker assert isinstance(new_args, tuple) 732*da0073e9SAndroid Build Coastguard Worker assert isinstance(new_kwargs, dict) 733*da0073e9SAndroid Build Coastguard Worker self.__update_args_kwargs(new_args, new_kwargs) 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker def _rename(self, candidate: str) -> None: 736*da0073e9SAndroid Build Coastguard Worker if candidate == self.name: 737*da0073e9SAndroid Build Coastguard Worker return 738*da0073e9SAndroid Build Coastguard Worker name = self.graph._graph_namespace.create_name(candidate, None) 739*da0073e9SAndroid Build Coastguard Worker self.name = name 740*da0073e9SAndroid Build Coastguard Worker self.graph._graph_namespace._rename_object(self, name) 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker def __setattr__(self, name: str, value: Any) -> None: 743*da0073e9SAndroid Build Coastguard Worker if name == 'name' and hasattr(self, "name"): 744*da0073e9SAndroid Build Coastguard Worker m = self.graph.owning_module 745*da0073e9SAndroid Build Coastguard Worker if getattr(m, "_replace_hook", None): 746*da0073e9SAndroid Build Coastguard Worker assert 
isinstance(value, str) 747*da0073e9SAndroid Build Coastguard Worker for user in self.users: 748*da0073e9SAndroid Build Coastguard Worker m._replace_hook(old=self, new=value, user=user) 749*da0073e9SAndroid Build Coastguard Worker update = False 750*da0073e9SAndroid Build Coastguard Worker if ( 751*da0073e9SAndroid Build Coastguard Worker hasattr(self, name) and 752*da0073e9SAndroid Build Coastguard Worker hasattr(self.graph, "_find_nodes_lookup_table") and 753*da0073e9SAndroid Build Coastguard Worker self in self.graph._find_nodes_lookup_table 754*da0073e9SAndroid Build Coastguard Worker ): 755*da0073e9SAndroid Build Coastguard Worker update = True 756*da0073e9SAndroid Build Coastguard Worker self.graph._find_nodes_lookup_table.remove(self) 757*da0073e9SAndroid Build Coastguard Worker object.__setattr__(self, name, value) 758*da0073e9SAndroid Build Coastguard Worker if update: 759*da0073e9SAndroid Build Coastguard Worker self.graph._find_nodes_lookup_table.insert(self) 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 762*da0073e9SAndroid Build Coastguard Workerdef map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: 763*da0073e9SAndroid Build Coastguard Worker """ 764*da0073e9SAndroid Build Coastguard Worker Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. 
765*da0073e9SAndroid Build Coastguard Worker """ 766*da0073e9SAndroid Build Coastguard Worker assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" 767*da0073e9SAndroid Build Coastguard Worker return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True) 770*da0073e9SAndroid Build Coastguard Workerdef map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: 771*da0073e9SAndroid Build Coastguard Worker """ 772*da0073e9SAndroid Build Coastguard Worker Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. 773*da0073e9SAndroid Build Coastguard Worker """ 774*da0073e9SAndroid Build Coastguard Worker if isinstance(a, tuple): 775*da0073e9SAndroid Build Coastguard Worker t = tuple([map_aggregate(elem, fn) for elem in a]) 776*da0073e9SAndroid Build Coastguard Worker # Support NamedTuple (if it has `_fields`) by repacking into original type. 
777*da0073e9SAndroid Build Coastguard Worker return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] 778*da0073e9SAndroid Build Coastguard Worker elif isinstance(a, list): 779*da0073e9SAndroid Build Coastguard Worker return immutable_list([map_aggregate(elem, fn) for elem in a]) 780*da0073e9SAndroid Build Coastguard Worker elif isinstance(a, dict): 781*da0073e9SAndroid Build Coastguard Worker rv = immutable_dict() 782*da0073e9SAndroid Build Coastguard Worker for k, v in a.items(): 783*da0073e9SAndroid Build Coastguard Worker dict.__setitem__(rv, k, map_aggregate(v, fn)) 784*da0073e9SAndroid Build Coastguard Worker return rv 785*da0073e9SAndroid Build Coastguard Worker elif isinstance(a, slice): 786*da0073e9SAndroid Build Coastguard Worker return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) 787*da0073e9SAndroid Build Coastguard Worker else: 788*da0073e9SAndroid Build Coastguard Worker return fn(a) 789