1# mypy: allow-untyped-defs
2from collections import defaultdict
3from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
4import torch.utils._pytree as pytree
5from . import _pytree as fx_pytree
6from ._compatibility import compatibility
7from torch._C import _NodeIter
8
9import os
10import contextlib
11from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable
12from dataclasses import dataclass
13from contextlib import contextmanager
14import copy
15import enum
16import torch
17import keyword
18import re
19import builtins
20import math
21import warnings
22import inspect
23
24__all__ = ["PythonCode", "CodeGen", "Graph"]
25
26if TYPE_CHECKING:
27    from .graph_module import GraphModule  # noqa: F401
28    from ._symbolic_trace import Tracer   # noqa: F401
29
30
31# Mapping of builtins to their `typing` equivalent.
32_origin_type_map = {
33    list: List,
34    dict: Dict,
35    set: Set,
36    frozenset: FrozenSet,
37    tuple: Tuple,
38}
39
40
41# Signature for functions that transform the body (`list[str]`) of the
42# generated code
43TransformCodeFunc = Callable[[List[str]], List[str]]
44
45
46class _CustomBuiltin(NamedTuple):
47    """Additional objs that we add to every graph's globals.
48
49    The repr() for some standard library objects is not valid Python code without
50    an import. For common objects of this sort, we bundle them in the globals of
51    every FX graph.
52    """
53    # How to import this object from the standard library.
54    import_str: str
55    # The actual object, produced from that import string.
56    obj: Any
57
58_custom_builtins: Dict[str, _CustomBuiltin] = {}
59
60
61def _register_custom_builtin(name: str, import_str: str, obj: Any):
62    _custom_builtins[name] = _CustomBuiltin(import_str, obj)
63
64
65_register_custom_builtin('inf', 'from math import inf', math.inf)
66_register_custom_builtin('nan', 'from math import nan', math.nan)
67_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None))
68_register_custom_builtin('torch', 'import torch', torch)
69_register_custom_builtin('device', 'from torch import device', torch.device)
70_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree)
71_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree)
72
73
74def _is_magic(x: str) -> bool:
75    return x.startswith('__') and x.endswith('__')
76
77
78def _snake_case(s: str) -> str:
79    """
80    Transforms the given string ``s`` to a Python-style variable name
81
82    Examples:
83        ``mod.snake_case`` -> ``mod.snake_case``
84        ``mod.pascalCase`` -> ``mod.pascal_case``
85        ``mod.ALL_CAPS`` -> ``mod.all_caps``
86    """
87    chars = []
88    prev_lower = False
89    for c in s:
90        if prev_lower and c.isupper():
91            chars.append('_')
92        chars.append(c.lower())
93        prev_lower = c.islower()
94    return ''.join(chars)
95
96
97def _is_from_torch(obj: Any) -> bool:
98    module_name = getattr(obj, '__module__', None)
99    if module_name is not None:
100        base_module = module_name.partition('.')[0]
101        return (
102            base_module == 'torch' and
103            not module_name.startswith("torch._dynamo.") and
104            not module_name.startswith("torch._inductor.")
105        )
106
107    name = getattr(obj, '__name__', None)
108    # exclude 'torch' itself: getattr(torch, 'torch') is torch, so the chain torch.torch.torch would match trivially
109    if name is not None and name != 'torch':
110        for guess in [torch, torch.nn.functional]:
111            if getattr(guess, name, None) is obj:
112                return True
113
114    return False
115
116
117class _Namespace:
118    """A context for associating names uniquely with objects.
119
120    The following invariants are enforced:
121    - Each object gets a single name.
122    - Each name is unique within a given namespace.
123    - Names generated do not shadow builtins, unless the object is indeed that builtin.
124    """
125    def __init__(self):
126        self._obj_to_name: Dict[Any, str] = {}
127        self._unassociated_names = set()
128        self._used_names: Set[str] = set()
129        self._base_count: Dict[str, int] = defaultdict(int)
130
131        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
132        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")
133
134    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
135        """Create a unique name.
136
137        Arguments:
138            candidate: used as the basis for the unique name, relevant to the user.
139            obj: If not None, an object that will be associated with the unique name.
140        """
141        if obj is not None and obj in self._obj_to_name:
142            return self._obj_to_name[obj]
143
144        # delete all characters that are illegal in a Python identifier
145        candidate = self._illegal_char_regex.sub('_', candidate)
146
147        if not candidate:
148            candidate = '_unnamed'
149
150        if candidate[0].isdigit():
151            candidate = f'_{candidate}'
152
153        match = self._name_suffix_regex.match(candidate)
154        if match is None:
155            base = candidate
156            num = None
157        else:
158            base, num_str = match.group(1, 2)
159            num = int(num_str)
160
161        candidate = base if num is None else f'{base}_{num}'
162        if not num:
163            num = self._base_count[base]
164
165        while candidate in self._used_names or self._is_illegal_name(candidate, obj):
166            num += 1
167            candidate = f'{base}_{num}'
168
169        self._used_names.add(candidate)
170        self._base_count[base] = num
171        if obj is None:
172            self._unassociated_names.add(candidate)
173        else:
174            self._obj_to_name[obj] = candidate
175        return candidate
176
177    def associate_name_with_obj(self, name: str, obj: Any):
178        """Associate a unique name with an object.
179
180        Neither `name` nor `obj` should be associated already.
181        """
182        assert obj not in self._obj_to_name
183        assert name in self._unassociated_names
184        self._obj_to_name[obj] = name
185        self._unassociated_names.remove(name)
186
187    def _is_illegal_name(self, name: str, obj: Any) -> bool:
188        # 1. keywords are never allowed as names.
189        if name in keyword.kwlist:
190            return True
191
192        # 2. Can't shadow a builtin name, unless you *are* that builtin.
193        if name in builtins.__dict__:
194            return obj is not builtins.__dict__[name]
195
196        # 3. Can't shadow our custom builtins either
197        if name in _custom_builtins:
198            return obj is not _custom_builtins[name].obj
199
200        return False
201
202    def _rename_object(self, obj: Any, name: str):
203        assert obj in self._obj_to_name
204        self._obj_to_name[obj] = name
205        self._used_names.add(name)
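# A minimal illustrative sketch (not executed by this module) of how
# _Namespace.create_name deduplicates names; the objects below are hypothetical:
#
#     ns = _Namespace()
#     obj_a, obj_b = object(), object()
#     ns.create_name('add', obj_a)   # -> 'add'
#     ns.create_name('add', obj_b)   # -> 'add_1'  (suffix keeps names unique)
#     ns.create_name('add', obj_a)   # -> 'add'    (same object keeps its name)
#     ns.create_name('list', None)   # -> 'list_1' (never shadows a builtin)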
206
207dtype_abbrs = {
208    torch.bfloat16: 'bf16',
209    torch.float64: 'f64',
210    torch.float32: 'f32',
211    torch.float16: 'f16',
212    torch.float8_e4m3fn: 'f8e4m3fn',
213    torch.float8_e5m2: 'f8e5m2',
214    torch.float8_e4m3fnuz: 'f8e4m3fnuz',
215    torch.float8_e5m2fnuz: 'f8e5m2fnuz',
216    torch.complex32: 'c32',
217    torch.complex64: 'c64',
218    torch.complex128: 'c128',
219    torch.int8: 'i8',
220    torch.int16: 'i16',
221    torch.int32: 'i32',
222    torch.int64: 'i64',
223    torch.bool: 'b8',
224    torch.uint8: 'u8',
225    torch.uint16: 'u16',
226    torch.uint32: 'u32',
227    torch.uint64: 'u64',
228    torch.bits16: 'b16',
229}
230
231@compatibility(is_backward_compatible=True)
232@dataclass
233class PythonCode:
234    """
235    Represents all the information necessary to exec or save a graph as Python code.
236    """
237    # Python source code for the forward function definition.
238    src: str
239    # Values in global scope during execution of `src`.
240    globals: Dict[str, Any]
241    # Optional mapping from the forward function's line number to
242    # node index.
243    _lineno_map: Optional[Dict[int, Optional[int]]]
244
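# Illustrative sketch (the helper name below is hypothetical, not part of this
# module): a PythonCode object pairs generated source with the globals that the
# source needs, so it can be turned into a callable with exec():
#
#     def _materialize(python_code: 'PythonCode') -> Callable[..., Any]:
#         scope = dict(python_code.globals)  # copy so exec() cannot mutate the original
#         exec(compile(python_code.src, '<fx-generated>', 'exec'), scope)
#         return scope['forward']            # assumes the default function name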
245
246def _format_target(base: str, target: str) -> str:
247    elems = target.split('.')
248    r = base
249    for e in elems:
250        if not e.isidentifier():
251            r = f'getattr({r}, "{e}")'
252        else:
253            r = f'{r}.{e}'
254    return r
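# Illustrative examples of _format_target (the inputs are hypothetical): components
# that are not valid identifiers fall back to getattr() in the generated code.
#
#     _format_target('self', 'sub.linear')  # -> 'self.sub.linear'
#     _format_target('self', 'layers.0')    # -> 'getattr(self.layers, "0")'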
255
256class _InsertPoint:
257    def __init__(self, graph, new_insert):
258        self.graph = graph
259        self.orig_insert, graph._insert = graph._insert, new_insert
260
261    def __enter__(self):
262        pass
263
264    def __exit__(self, type, value, tb):
265        self.graph._insert = self.orig_insert
266
267class _node_list:
268    def __init__(self, graph: 'Graph', direction: str = '_next'):
269        assert direction in ['_next', '_prev']
270        self.graph = graph
271        self.direction = direction
272
273    def __len__(self):
274        return self.graph._len
275
276    def __iter__(self):
277        assert self.direction == "_prev" or self.direction == "_next"
278        yield from _NodeIter(self.graph._root, self.direction == "_prev")
279
280    def __reversed__(self):
281        return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
282
283class _PyTreeInfo(NamedTuple):
284    """
285    Contains extra info stored when we're using Pytrees
286    """
287    orig_args: List[str]
288    in_spec: pytree.TreeSpec
289    out_spec: Optional[pytree.TreeSpec]
290
291@dataclass(frozen=True)
292class _ParsedStackTrace:
293    """
294    Represents the top-most frame of a parsed stack trace
295    """
296    file: str
297    lineno: str
298    name: str
299    code: str
300
301    def get_summary_str(self):
302        return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}'
303
304# get the innermost File:lineno frame and its code line from stack_trace
305def _parse_stack_trace(stack_trace: str):
306    if stack_trace is None:
307        return None
308    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
309    lines = stack_trace.strip().split('\n')
310    # stacktrace should have innermost frame last, so we
311    # iterate backwards to find the first line that starts
312    # with 'File '
313    summary_str = ""
314    for idx in range(len(lines) - 2, -1, -1):
315        line = lines[idx].strip()
316        matches = pattern.match(line)
317        if matches:
318            file = matches.group(1)
319            lineno = matches.group(2)
320            name = matches.group(3)
321            # next line should be the code
322            code = lines[idx + 1].strip()
323            return _ParsedStackTrace(file, lineno, name, code)
324    return None
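# Illustrative sketch (the trace text is hypothetical) of what _parse_stack_trace
# extracts from the innermost frame:
#
#     trace = (
#         'File "example.py", line 10, in forward\n'
#         '    return torch.relu(x)\n'
#     )
#     frame = _parse_stack_trace(trace)
#     # frame.file == 'example.py', frame.lineno == '10', frame.name == 'forward',
#     # frame.code == 'return torch.relu(x)'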
325
326@compatibility(is_backward_compatible=False)
327class CodeGen:
328    def __init__(self):
329        self._body_transformer: Optional[TransformCodeFunc] = None
330        self._func_name: str = "forward"
331
332    def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
333        """
334        Given the free variables and a return annotation, generates the beginning of the FX function.
335        By default, ``gen_fn_def(['a', 'b'], '')`` returns ``'def forward(self, a, b):'`` (``self._func_name`` supplies the name, and a ``self`` argument is prepended if missing)
336        """
337        # If the original function didn't have self as its first argument, we
338        # would have added it.
339        if len(free_vars) == 0 or free_vars[0] != 'self':
340            free_vars.insert(0, 'self')
341        return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
342
343    def generate_output(self, output_args: Argument) -> str:
344        """
345        Given the output arguments, generates the return statement of the FX function.
346        Note: The returned statement should not be indented.
347        """
348        return f'return {repr(output_args)}'
349
350    def process_inputs(self, *args: Any) -> Any:
351        """
352        Transforms the inputs so that the graph can take them as arguments, as
353        non-default codegen may result in the inputs to the function being
354        different from the inputs to the graph.
355
356        If the graph were directly runnable, this invariant should hold:
357        `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
358        """
359        return args
360
361    def process_outputs(self, outputs: Any) -> Any:
362        """
363        Transforms the raw outputs of the graph into the form returned by the generated code.
364
365        See ``process_inputs`` for more details.
366        """
367        return outputs
368
369    def additional_globals(self) -> List[Tuple[str, Any]]:
370        """
371        If your codegen uses extra global values, add tuples of (identifier, reference to the value) here.
372        For example, return [('List', typing.List)] if you need ``List`` in the global context.
373        """
374        return []
375
376    def _gen_python_code(
377        self, nodes, root_module: str, namespace: _Namespace, *,
378        verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False
379    ) -> PythonCode:
380        free_vars: List[str] = []
381        body: List[str] = []
382        globals_: Dict[str, Any] = {}
383        wrapped_fns: Dict[str, None] = {}
384
385        # Wrap string in list to pass by reference
386        maybe_return_annotation : List[str] = ['']
387        include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1")
388        include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1")
389
390        def add_global(name_hint: str, obj: Any):
391            """Add an obj to be tracked as a global.
392
393            We call this for names that reference objects external to the
394            Graph, like functions or types.
395
396            Returns: the global name that should be used to reference 'obj' in generated source.
397            """
398            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
399                # HACK: workaround for how torch custom ops are registered. We
400                # can't import them like normal modules so they must retain their
401                # fully qualified name.
402                return _get_qualified_name(obj)
403
404            # normalize the name hint to get a proper identifier
405            global_name = namespace.create_name(name_hint, obj)
406
407            if global_name in globals_:
408                assert globals_[global_name] is obj
409                return global_name
410            globals_[global_name] = obj
411            return global_name
412
413        # Pre-fill the globals table with registered builtins.
414        for name, (_, obj) in _custom_builtins.items():
415            add_global(name, obj)
416
417        def type_repr(o : Any):
418            if o == ():
419                # Empty tuple is used for empty tuple type annotation Tuple[()]
420                return '()'
421
422            typename = _type_repr(o)
423
424            if hasattr(o, '__origin__'):
425                # This is a generic type, e.g. typing.List[torch.Tensor]
426                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
427                origin_typename = add_global(_type_repr(origin_type), origin_type)
428
429                if hasattr(o, '__args__'):
430                    # Assign global names for each of the inner type variables.
431                    args = [type_repr(arg) for arg in o.__args__]
432
433                    if len(args) == 0:
434                        # Bare type, such as `typing.Tuple` with no subscript
435                        # This code path is used in Python < 3.9
436                        return origin_typename
437
438                    return f'{origin_typename}[{",".join(args)}]'
439                else:
440                    # Bare type, such as `typing.Tuple` with no subscript
441                    # This code path is used in Python 3.9+
442                    return origin_typename
443
444            # Common case: this is a regular module name like 'foo.bar.baz'
445            return add_global(typename, o)
446
447        codes = {
448            "yellow": "\033[33m",
449            "cyan": "\033[36m",
450            "green": "\033[32m",
451            "blue": "\033[34m",
452            "red": "\033[31m",
453            "dim": "\033[2m",
454            "dim_blue": "\033[2m\033[34m",
455            "dim_green": "\033[2m\033[32m",
456            "reset": "\033[0m",
457        }
458
459        def make_wrapper_func(name):
460            def f(s):
461                if colored:
462                    return f"{codes[name]}{s}{codes['reset']}"
463                return s
464            return f
465
466        yellow = make_wrapper_func("yellow")
467        cyan = make_wrapper_func("cyan")
468        red = make_wrapper_func("red")
469        green = make_wrapper_func("green")
470        dim_green = make_wrapper_func("dim_green")
471        dim = make_wrapper_func("dim")
472        dim_blue = make_wrapper_func("dim_blue")
473        blue = make_wrapper_func("blue")
474
475        def _get_repr(arg: Any) -> str:
476            # Handle NamedTuples (if it has `_fields`) via add_global.
477            if isinstance(arg, tuple) and hasattr(arg, '_fields'):
478                qualified_name = _get_qualified_name(type(arg))
479                global_name = add_global(qualified_name, type(arg))
480                return f"{global_name}{repr(tuple(arg))}"
481            elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
482                qualified_name = _get_qualified_name(arg)
483                global_name = add_global(qualified_name, arg)
484                return f"{global_name}"
485            elif isinstance(arg, enum.Enum):
486                cls = arg.__class__
487                clsname = add_global(cls.__name__, cls)
488                return f"{clsname}.{arg.name}"
489            elif isinstance(arg, Node):
490                return repr(arg)
491            elif isinstance(arg, torch.Tensor):
492                size = list(arg.size())
493                dtype = str(arg.dtype).split(".")[-1]
494                return f"torch.Tensor(size={size}, dtype={dtype})"
495            else:
496                return blue(repr(arg))
497
498
499        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
500            args_s = ', '.join(_get_repr(a) for a in args)
501            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
502            if args_s and kwargs_s:
503                return f'{args_s}, {kwargs_s}'
504            return args_s or kwargs_s
505
506        # Run through the nodes in reverse order and record the first instance of a use
507        # of a given node. This represents the *last* use of the node in the
508        # execution order of the program, which we will use to free unused
509        # values
510        node_to_last_use : Dict[Node, Node] = {}
511        user_to_last_uses : Dict[Node, List[Node]] = {}
512
513        def register_last_uses(n : Node, user : Node):
514            if n not in node_to_last_use:
515                node_to_last_use[n] = user
516                user_to_last_uses.setdefault(user, []).append(n)
517
518        for node in reversed(nodes):
519            map_arg(node.args, lambda n: register_last_uses(n, node))
520            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
521
522        def delete_unused_values(user : Node):
523            """
524            Delete values after their last use. This ensures that values that are
525            not used in the remainder of the code are freed and the memory usage
526            of the code is optimal.
527            """
528            if user.op == 'placeholder':
529                return
530            if user.op == 'output':
531                body.append('\n')
532                return
533            nodes_to_delete = user_to_last_uses.get(user, [])
534
535            if len(user.users.keys()) == 0:
536            # This node is not used by any others, but it is also not removed
537            # by DCE because it has side effects. We want to free its outputs
538            # right after its execution finishes to save memory.
539                nodes_to_delete.append(user)
540
541            if len(nodes_to_delete):
542                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
543                body.append(f';  {dim(to_delete_str)}\n')
544            else:
545                body.append('\n')
546
547        prev_stacktrace = None
548
549        def append_stacktrace_summary(node : Node):
550            """
551            Append a summary of the stacktrace to the generated code. This is
552            useful for debugging.
553            """
554            nonlocal prev_stacktrace
555
556            if node.op not in {'placeholder', 'output'}:
557                if node.stack_trace:
558                    if node.stack_trace != prev_stacktrace:
559                        prev_stacktrace = node.stack_trace
560                        summary_str = ""
561
562                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
563                            summary_str = parsed_stack_trace.get_summary_str()
564
565                        body.append(f'\n {dim("# " + summary_str)}\n')
566                elif prev_stacktrace != "":
567                    prev_stacktrace = ""
568                    no_stacktrace_msg = "# No stacktrace found for following nodes"
569                    body.append(f'\n{dim(no_stacktrace_msg)}\n')
570
571        def stringify_shape(shape : Iterable) -> str:
572            return f"[{', '.join(str(x) for x in shape)}]"
573
574        def emit_node(node : Node):
575            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
576
577            if verbose:
578                # override annotation with more detailed information
579                from torch.fx.experimental.proxy_tensor import py_sym_types
580                from torch.fx.passes.shape_prop import TensorMetadata
581
582                meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None)))
583                # use string as annotation, to make it valid python code
584
585                if isinstance(meta_val, torch.Tensor):
586                    stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else ""
587                    device_annotation = f"{meta_val.device}" if include_device else ""
588                    maybe_type_annotation = \
589                        f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \
590                        f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"'
591                elif isinstance(meta_val, py_sym_types):
592                    maybe_type_annotation = f': "Sym({meta_val})"'
593                elif isinstance(meta_val, TensorMetadata):
594                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
595
596            if node.op == 'placeholder':
597                assert isinstance(node.target, str)
598                maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}'
599                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
600                raw_name = node.target.replace('*', '')
601                if raw_name != repr(node):
602                    body.append(f'{repr(node)} = {raw_name}\n')
603                return
604            elif node.op == 'call_method':
605                assert isinstance(node.target, str)
606                body.append(
607                    f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}'
608                    f'({_format_args(node.args[1:], node.kwargs)})')
609                return
610            elif node.op == 'call_function':
611                assert callable(node.target)
612                # pretty print operators
613                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
614                    assert isinstance(node.args, tuple)
615                    body.append(f'{repr(node)}{maybe_type_annotation} = '
616                                f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}')
617                    return
618
619                # pretty print inplace operators; required for jit.script to work properly
620                # not currently supported in normal FX graphs, but generated by torchdynamo
621                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
622                    body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))};  '
623                                f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}')
624                    return
625
626                qualified_name = _get_qualified_name(node.target)
627                global_name = add_global(qualified_name, node.target)
628                # special case for getattr: node.args could be 2-argument or 3-argument
629                # 2-argument: attribute access; 3-argument: fall through to a getattr() call with a default value
630                if global_name == 'getattr' and \
631                   isinstance(node.args, tuple) and \
632                   isinstance(node.args[1], str) and \
633                   node.args[1].isidentifier() and \
634                   len(node.args) == 2:
635                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}')
636                    return
637                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
638                if node.meta.get('is_wrapped', False):
639                    wrapped_fns.setdefault(global_name)
640                return
641            elif node.op == 'call_module':
642                assert isinstance(node.target, str)
643                body.append(f'{repr(node)}{maybe_type_annotation} = '
644                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
645                return
646            elif node.op == 'get_attr':
647                assert isinstance(node.target, str)
648                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
649                return
650            elif node.op == 'output':
651                if node.type is not None:
652                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
653                body.append(self.generate_output(node.args[0]))
654                return
655            raise NotImplementedError(f'node: {node.op} {node.target}')
656
657        for i, node in enumerate(nodes):
658            # NOTE: emit_node does not emit a string with newline. It depends
659            # on delete_unused_values to append one
660            if verbose:
661                append_stacktrace_summary(node)
662            # emit a counter comment to keep track of
663            # node index, which will be deleted later
664            # after going through _body_transformer
665            body.append(f"# COUNTER: {i}\n")
666            emit_node(node)
667            delete_unused_values(node)
668
669        if len(body) == 0:
670            # If the Graph has no non-placeholder nodes, no lines for the body
671            # have been emitted. To continue to have valid Python code, emit a
672            # single pass statement
673            body.append('pass\n')
674
675
676
677        if len(wrapped_fns) > 0:
678            wrap_name = add_global('wrap', torch.fx.wrap)
679            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
680        else:
681            wrap_stmts = ''
682
683        if self._body_transformer:
684            body = self._body_transformer(body)
685
686        for name, value in self.additional_globals():
687            add_global(name, value)
688
689        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
690
691        # remove counter and generate lineno to node index mapping
692        lineno_map: Dict[int, Optional[int]] = {}
693        prologue_len = prologue.count('\n') + 1
694        new_lines: List[str] = []
695        cur_idx = None
696        for line in ''.join(body).split('\n'):
697            counter = re.search(r"# COUNTER: (\d+)", line)
698            if counter and counter.group(1) is not None:
699                cur_idx = int(counter.group(1))
700            else:
701                lineno_map[len(new_lines) + prologue_len] = cur_idx
702                new_lines.append(line)
703
704        code = "\n".join(new_lines).lstrip('\n')
705        code = '\n'.join('    ' + line for line in code.split('\n'))
706
707        fn_code = f"""
708{wrap_stmts}
709
710{prologue}
711{code}"""
712        return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
713
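# A minimal illustrative sketch of customizing codegen, modeled on the pattern used
# in FX's tests; the class name and the set_codegen/recompile usage are assumptions,
# not part of this file. The subclass makes the generated forward() accept a single
# list argument instead of separate positional inputs:
#
#     import typing
#
#     class ListArgsCodeGen(CodeGen):
#         def gen_fn_def(self, free_vars, maybe_return_annotation):
#             # unpack the single list argument back into the placeholder names
#             return (f"def forward(self, args_list: List[torch.Tensor])"
#                     f"{maybe_return_annotation}:\n"
#                     f"    {', '.join(free_vars)}, = args_list")
#
#         def additional_globals(self):
#             return [('List', typing.List)]
#
#         def process_inputs(self, *inputs):
#             assert len(inputs) == 1
#             return inputs[0]
#
#     # gm.graph.set_codegen(ListArgsCodeGen()); gm.recompile()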
714
715# Ideally, we'd like to refactor all of the pytree logic into this codegen
716# class. Unfortunately, there are 3 areas we currently need extra logic in FX.
717# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
718# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
719#    Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
720# 3. We currently can't register the pytree imports with `add_global` - not sure why.
721class _PyTreeCodeGen(CodeGen):
722    def __init__(self, pytree_info: _PyTreeInfo):
723        super().__init__()
724        self.pytree_info: _PyTreeInfo = pytree_info
725
726    def process_inputs(self, *inputs: Any) -> Any:
727        flat_args = pytree.arg_tree_leaves(*inputs)
728        return flat_args
729
730    def process_outputs(self, out: Any) -> Any:
731        if self.pytree_info is None or self.pytree_info.out_spec is None:
732            return out
733        if not isinstance(out, (list, tuple)):
734            out = [out]
735        assert self.pytree_info.out_spec is not None
736        return pytree.tree_unflatten(out, self.pytree_info.out_spec)
737
738    def gen_fn_def(self, free_vars, maybe_return_annotation):
739        # Given a user function/model:
740        #   myargs = [myargs0, myargs1]
741        #   mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
742        #   def forward(self, mypos, *myargs, mykey=None, **mykwargs):
743        #
744        # The generated code flattens all keywords into positional arguments for `forward()`
745        #   e.g. forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
746        #
747        # Within `forward`, ``tree_flatten_spec`` still parses args and kwargs separately
748        #   e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
749        #                           {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
750        #                          self._in_spec)
751        #
752        # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
753        #   e.g. tree_flatten_spec([mypos, myargs0, myargs1], self._in_spec)
754        if self.pytree_info is None:
755            return super().gen_fn_def(free_vars, maybe_return_annotation)
756
757        fn_args = self.pytree_info.orig_args
758        has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
759        if has_orig_self:
760            free_vars.insert(0, 'self')
761        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
762
763        if len(free_vars) > 0:  # pytree has placeholders in it
764            # when kwargs is present, in_spec is tuple(args, kwargs)
765            has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
766                self.pytree_info.in_spec.num_children == 2 and \
767                self.pytree_info.in_spec.children_specs[0].type == tuple and \
768                self.pytree_info.in_spec.children_specs[1].type == dict
769            fn_kwargs = '{}'
770            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
771            if has_args_kwargs_tuple:
772                count_args = self.pytree_info.in_spec.children_specs[0].num_children
773                fn_args = self.pytree_info.orig_args[:count_args]
774                fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
775                                  self.pytree_info.in_spec.children_specs[1].context,
776                                  self.pytree_info.orig_args[count_args:])) + '}'
777                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
778
779            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
780            # we need to split it to two lines:
781            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
782            # one for code: `var1, var2, = function_call()`
783            without_annotation = [x.split(":")[0] for x in free_vars]
784            has_annotation = [x + "; " for x in free_vars if ":" in x]
785            if len(has_annotation) > 0:
786                fn_definition += "\n    " + "".join(has_annotation) + "\n"
787            fn_definition += f"""
788    {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
789        return fn_definition
790
791    def generate_output(self, output_args):
792        if self.pytree_info and self.pytree_info.out_spec:
793            return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
794        else:
795            return super().generate_output(output_args)
796
797class _FindNodesLookupTable:
798    """
799    Side table for the graph for the purpose of doing fast queries
800    """
801    def __init__(self):
802        self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict)
803
804    def _key(self, node) -> Tuple[str, Optional[Target]]:
805        return (node.op, node.target if node.op == "call_function" else None)
806
807    def __contains__(self, node) -> bool:
808        return node in self.table[self._key(node)]
809
810    def insert(self, node: Node) -> None:
811        self.table[self._key(node)][node] = None
812
813    def remove(self, node: Node) -> None:
814        self.table[self._key(node)].pop(node)
815
816    def find_nodes(self, *, op: str, target: Optional['Target'] = None):
817        if op == "call_function":
818            assert target is not None
819            return dict(self.table[(op, target)]).keys()
820
821        if target is None:
822            return dict(self.table[(op, None)]).keys()
823
824        # op is call_method, get_attr, call_module
825        return [node for node in self.table[(op, None)].keys() if node.target == target]
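# Illustrative note on the keying above (the nodes are hypothetical): call_function
# nodes are bucketed by their exact target, so those queries are a single dict hit,
# while every other op shares one bucket per op and is filtered by target on demand.
#
#     table._key(add_node)   # -> ('call_function', torch.ops.aten.add.Tensor)
#     table._key(conv_node)  # -> ('call_module', None)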
826
827@compatibility(is_backward_compatible=True)
828class Graph:
829    """
830    ``Graph`` is the main data structure used in the FX Intermediate Representation.
831    It consists of a series of ``Node`` s, each representing a callsite (or other
832    syntactic construct). The list of ``Node`` s, taken together, constitutes a
833    valid Python function.
834
835    For example, the following code
836
837    .. code-block:: python
838
839        import torch
840        import torch.fx
841
842        class MyModule(torch.nn.Module):
843            def __init__(self):
844                super().__init__()
845                self.param = torch.nn.Parameter(torch.rand(3, 4))
846                self.linear = torch.nn.Linear(4, 5)
847
848            def forward(self, x):
849                return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
850
851        m = MyModule()
852        gm = torch.fx.symbolic_trace(m)
853
854    Will produce the following Graph::
855
856        print(gm.graph)
857
858    .. code-block:: text
859
860        graph(x):
861            %linear_weight : [num_users=1] = self.linear.weight
862            %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
863            %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
864            %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
865            %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
866            %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
867            return topk_1
868
869    For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
870    """
871
872    @compatibility(is_backward_compatible=True)
873    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
874                 tracer_extras: Optional[Dict[str, Any]] = None):
875        """
876        Construct an empty Graph.
877        """
878        self._root : Node = Node(self, '', 'root', '', (), {})
879        self._used_names : Dict[str, int] = {}  # base name -> number
880        self._insert = self._root.prepend
881        self._len = 0
882        self._graph_namespace = _Namespace()
883        self._owning_module = owning_module
884        self._tracer_cls = tracer_cls
885        self._tracer_extras = tracer_extras
886        self._codegen = CodeGen()
887        self._co_fields : Dict[str, Any] = {}
888        self._find_nodes_lookup_table = _FindNodesLookupTable()
889
890    @property
891    def owning_module(self):
892        return self._owning_module
893
894    @owning_module.setter
895    def owning_module(self, mod: Optional["GraphModule"]):
896        self._owning_module = mod
897
898    @property
899    def nodes(self) -> _node_list:
900        """
901        Get the list of Nodes that constitute this Graph.
902
903        Note that this ``Node`` list representation is a doubly-linked list. Mutations
904        during iteration (e.g. delete a Node, add a Node) are safe.
905
906        Returns:
907
908            A doubly-linked list of Nodes. Note that ``reversed`` can be called on
909            this list to switch iteration order.
910        """
911        return _node_list(self)
912
913    @compatibility(is_backward_compatible=False)
914    def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True):
915        """
916        Allows for fast querying of nodes.
917
918        Args:
919
920            op (str): the name of the operation
921
922            target (Optional[Target]): the target of the node. For call_function,
923                the target is required. For other ops, the target is optional.
924
925            sort (bool): whether to return nodes in the order they appear
926                         in the graph.
927
928        Returns:
929
930            Iterable of nodes with the requested op and target.
931        """
932        node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target)
933        if sort:
934            return sorted(node_list)
935        return node_list
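        # Illustrative usage sketch (the traced module `gm` and the chosen target
        # are hypothetical):
        #
        #     placeholders = gm.graph.find_nodes(op="placeholder")
        #     adds = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor)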
936
937    @compatibility(is_backward_compatible=True)
938    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
939        """
940        Copy all nodes from a given graph into ``self``.
941
942        Args:
943
944            g (Graph): The source graph from which to copy Nodes.
945
946            val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping
947                from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed
948                in with values in it already to override copying of certain values.
949
950        Returns:
951
952            The value in ``self`` that is now equivalent to the output value in ``g``,
953            if ``g`` had an ``output`` node. ``None`` otherwise.
954        """
955        for node in g.nodes:
956            if node in val_map:
957                continue
958            if node.op == 'output':
959                rv = map_arg(node.args[0], lambda n: val_map[n])
960                return rv if not return_output_node else (rv, node)
961            val_map[node] = self.node_copy(node, lambda n : val_map[n])
962        return None
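        # Illustrative usage sketch (graphs `g1` and `g2` are hypothetical): copy
        # `g2` into `g1`, then emit an output for the copied result.
        #
        #     val_map: Dict[Node, Node] = {}
        #     copied_out = g1.graph_copy(g2, val_map)
        #     g1.output(copied_out)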
963
964    def __deepcopy__(self, memo=None) -> 'Graph':
965        """
966        Explicitly implement __deepcopy__ to prevent excessive recursion depth
967        from the default implementation. This uses graph_copy to copy the nodes
968        in an iterative way, rather than recursively. It also populates the
969        memoization table to prevent unnecessary copies (e.g. references to
970        nodes or other parts of the Graph from a custom GraphModule implementation).
971        """
972        memo = memo if memo else {}
973        g = Graph(tracer_cls=self._tracer_cls)
974        output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
975        g._codegen = copy.deepcopy(self._codegen)
976        assert isinstance(output_vals, tuple)
977        output_val, old_output_node = output_vals
978        new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None))
979        new_output_node.meta = copy.copy(old_output_node.meta)
980        return g
981
982    @compatibility(is_backward_compatible=True)
983    def create_node(self, op: str, target: 'Target',
984                    args: Optional[Tuple['Argument', ...]] = None,
985                    kwargs: Optional[Dict[str, 'Argument']] = None,
986                    name: Optional[str] = None,
987                    type_expr: Optional[Any] = None) -> Node:
988        """
989        Create a ``Node`` and add it to the ``Graph`` at the current insert-point.
990        Note that the current insert-point can be set via :meth:`Graph.inserting_before`
991        and :meth:`Graph.inserting_after`.
992
993        Args:
994            op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
995                'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
996                described in the ``Graph`` docstring.
997
998            args (Optional[Tuple[Argument, ...]]): a tuple of positional arguments to this node.
999
1000            kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node
1001
1002            name (Optional[str]): an optional string name for the ``Node``.
1003                This will influence the name of the value assigned to in the
1004                Python generated code.
1005
1006            type_expr (Optional[Any]): an optional type annotation representing the
1007                Python type the output of this node will have.
1008
1009        Returns:
1010
1011            The newly-created and inserted node.
1012        """
1013        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
1014        args = () if args is None else args
1015        kwargs = {} if kwargs is None else kwargs
1016        assert isinstance(args, tuple), "args must be a tuple"
1017        assert isinstance(kwargs, dict), "kwargs must be a dict"
1018
1019        candidate = name if name is not None else self._target_to_str(target)
1020        name = self._graph_namespace.create_name(candidate, None)
1021        n = Node(self, name, op, target, args, kwargs, type_expr)
1022
1023        if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None:
1024            for f in self.owning_module._create_node_hooks:
1025                f(n)
1026
1027        self._graph_namespace.associate_name_with_obj(name, n)
1028
1029        self._insert(n)
1030        self._find_nodes_lookup_table.insert(n)
1031        self._len += 1
1032        return n
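        # Illustrative sketch of building a small graph by hand with create_node
        # (all names below are hypothetical):
        #
        #     import operator
        #     g = Graph()
        #     x = g.create_node('placeholder', 'x')
        #     y = g.create_node('placeholder', 'y')
        #     total = g.create_node('call_function', operator.add, args=(x, y))
        #     g.create_node('output', 'output', args=(total,))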
1033
1034    @compatibility(is_backward_compatible=False)
1035    def process_inputs(self, *args):
1036        """
1037        Processes args so that they can be passed to the FX graph.
1038        """
1039        return self._codegen.process_inputs(*args)
1040
1041    @compatibility(is_backward_compatible=False)
1042    def process_outputs(self, out):
1043        return self._codegen.process_outputs(out)
1044
1045
1046    @compatibility(is_backward_compatible=True)
1047    def erase_node(self, to_erase : Node) -> None:
1048        """
1049        Erases a ``Node`` from the ``Graph``. Throws an exception if
1050        there are still users of that node in the ``Graph``.
1051
1052        Args:
1053
1054            to_erase (Node): The ``Node`` to erase from the ``Graph``.
1055        """
1056        if len(to_erase.users) > 0:
1057            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
1058                               f'users in the graph: {to_erase.users}!')
1059        if to_erase.graph != self:
1060            raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!")
1061        if to_erase._erased:
1062            warnings.warn(f"erase_node({to_erase}) on an already erased node")
1063            return
1064
1065        if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None:
1066            for f in self.owning_module._erase_node_hooks:
1067                f(to_erase)
1068
1069        self._find_nodes_lookup_table.remove(to_erase)
1070        to_erase._remove_from_list()
1071        to_erase._erased = True  # iterators may retain handles to erased nodes
1072        self._len -= 1
1073
1074        # Null out this Node's argument nodes so that the Nodes referred to
1075        # can update their ``users`` accordingly
1076        new_args = map_arg(to_erase.args, lambda n: None)
1077        assert isinstance(new_args, tuple)
1078        to_erase.args = new_args
1079        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
1080        assert isinstance(new_kwargs, dict)
1081        to_erase.kwargs = new_kwargs
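        # Illustrative sketch (`g` and `old_node` are hypothetical): a node must be
        # unused before it can be erased, so redirect its users first.
        #
        #     with g.inserting_after(old_node):
        #         new_node = g.call_function(torch.neg, args=(old_node.args[0],))
        #     old_node.replace_all_uses_with(new_node)
        #     g.erase_node(old_node)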
1082
1083    @compatibility(is_backward_compatible=True)
1084    def inserting_before(self, n: Optional[Node] = None):
1085        """Set the point at which create_node and companion methods will insert into the graph.
1086        When used within a 'with' statement, this will temporarily set the insert point and
1087        then restore it when the with statement exits::
1088
1089            with g.inserting_before(n):
1090                ... # inserting before node n
1091            ... # insert point restored to what it was previously
1092            g.inserting_before(n) #  set the insert point permanently
1093
1094        Args:
1095
1096            n (Optional[Node]): The node before which to insert. If None this will insert before
1097                the beginning of the entire graph.
1098
1099        Returns:
1100            A resource manager that will restore the insert point on ``__exit__``.
1101        """
1102        if n is None:
1103            return self.inserting_after(self._root)
1104        assert n.graph == self, "Node to insert before is not in graph."
1105        return _InsertPoint(self, n.prepend)
1106
1107    @compatibility(is_backward_compatible=True)
1108    def inserting_after(self, n: Optional[Node] = None):
1109        """Set the point at which create_node and companion methods will insert into the graph.
1110        When used within a 'with' statement, this will temporarily set the insert point and
1111        then restore it when the with statement exits::
1112
1113            with g.inserting_after(n):
1114                ... # inserting after node n
1115            ... # insert point restored to what it was previously
1116            g.inserting_after(n) #  set the insert point permanently
1117
1118        Args:
1119
1120            n (Optional[Node]): The node after which to insert. If None this will insert at
1121                the end of the entire graph.
1122
1123        Returns:
1124            A resource manager that will restore the insert point on ``__exit__``.
1125        """
1126        if n is None:
1127            return self.inserting_before(self._root)
1128        assert n.graph == self, "Node to insert after is not in graph."
1129        return _InsertPoint(self, n.append)
1130
1131    @compatibility(is_backward_compatible=True)
1132    def placeholder(self, name: str, type_expr: Optional[Any] = None,
1133                    default_value : Any = inspect.Signature.empty) -> Node:
1134        """
1135        Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
1136        a function input.
1137
1138        Args:
1139
1140            name (str): A name for the input value. This corresponds to the name
1141                of the positional argument to the function this ``Graph`` represents.
1142
1143            type_expr (Optional[Any]): an optional type annotation representing the
1144                Python type the output of this node will have. This is needed in some
1145                cases for proper code generation (e.g. when the function is used
1146                subsequently in TorchScript compilation).
1147
1148            default_value (Any): The default value this function argument should take
1149                on. NOTE: since `None` is itself a valid default value, pass
1150                `inspect.Signature.empty` as this argument to specify that the
1151                parameter has _no_ default value.
1152
1153        .. note::
1154            The same insertion point and type expression rules apply for this method
1155            as ``Graph.create_node``.
1156        """
1157        args = () if default_value is inspect.Signature.empty else (default_value,)
1158        return self.create_node('placeholder', name, args=args, type_expr=type_expr)
1159
1160    @compatibility(is_backward_compatible=True)
1161    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
1162        """
1163        Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
1164        fetch of an attribute from the ``Module`` hierarchy.
1165
1166        Args:
1167
1168            qualified_name (str): the fully-qualified name of the attribute to be retrieved.
1169                For example, if the traced Module has a submodule named ``foo``, which has a
1170                submodule named ``bar``, which has an attribute named ``baz``, the qualified
1171                name ``foo.bar.baz`` should be passed as ``qualified_name``.
1172
1173            type_expr (Optional[Any]): an optional type annotation representing the
1174                Python type the output of this node will have.
1175
1176
1177        Returns:
1178
1179            The newly-created and inserted ``get_attr`` node.
1180
1181        .. note::
1182            The same insertion point and type expression rules apply for this method
1183            as ``Graph.create_node``.
1184        """
1185        def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool:
1186            module_path, _, name = qualified_name.rpartition(".")
1187
1188            try:
1189                submod: torch.nn.Module = mod.get_submodule(module_path)
1190            except AttributeError:
1191                warnings.warn(f"Failed to fetch module {module_path}!")
1192                return False
1193
1194            if not hasattr(submod, name):
1195                return False
1196
1197            res = getattr(submod, name)
1198
1199            if (not isinstance(res, torch.nn.Module)
1200                    and not isinstance(res, torch.nn.Parameter)
1201                    and name not in submod._buffers):
1202                return False
1203
1204            return True
1205
1206        if (self.owning_module and
1207                not _get_attr_reference_exists(self.owning_module, qualified_name)):
1208            warnings.warn("Attempted to insert a get_attr Node with no "
1209                          "underlying reference in the owning "
1210                          "GraphModule! Call "
1211                          "GraphModule.add_submodule to add the "
1212                          "necessary submodule, "
1213                          "GraphModule.add_parameter to add the "
1214                          "necessary Parameter, or "
1215                          "nn.Module.register_buffer to add the "
1216                          "necessary buffer", stacklevel=2)
1217        return self.create_node('get_attr', qualified_name, type_expr=type_expr)
1218
1219    @compatibility(is_backward_compatible=True)
1220    def call_module(self,
1221                    module_name: str,
1222                    args: Optional[Tuple['Argument', ...]] = None,
1223                    kwargs: Optional[Dict[str, 'Argument']] = None,
1224                    type_expr: Optional[Any] = None) -> Node:
1225        """
1226        Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
1227        represents a call to the forward() function of a ``Module`` in the ``Module``
1228        hierarchy.
1229
1230        Args:
1231
1232            module_name (str): The qualified name of the ``Module`` in the ``Module``
1233                hierarchy to be called. For example, if the traced ``Module`` has a
1234                submodule named ``foo``, which has a submodule named ``bar``, the
1235                qualified name ``foo.bar`` should be passed as ``module_name`` to
1236                call that module.
1237
1238            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
1239                to the called method. Note that this should *not* include a ``self`` argument.
1240
1241            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
1242                to the called method
1243
1244            type_expr (Optional[Any]): an optional type annotation representing the
1245                Python type the output of this node will have.
1246
1247        Returns:
1248
1249            The newly-created and inserted ``call_module`` node.
1250
1251        .. note::
1252            The same insertion point and type expression rules apply for this method
1253            as :meth:`Graph.create_node`.
1254        """
1255        if (self.owning_module and
1256                self.owning_module.get_submodule(module_name) is None):
1257            warnings.warn("Attempted to insert a call_module Node with "
1258                          "no underlying reference in the owning "
1259                          "GraphModule! Call "
1260                          "GraphModule.add_submodule to add the "
1261                          "necessary submodule")
1262        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
1263
1264    @compatibility(is_backward_compatible=True)
1265    def call_method(self,
1266                    method_name: str,
1267                    args: Optional[Tuple['Argument', ...]] = None,
1268                    kwargs: Optional[Dict[str, 'Argument']] = None,
1269                    type_expr: Optional[Any] = None) -> Node:
1270        """
1271        Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
1272        represents a call to a given method on the 0th element of ``args``.
1273
1274        Args:
1275
1276            method_name (str): The name of the method to apply to the self argument.
1277                For example, if args[0] is a ``Node`` representing a ``Tensor``,
1278                then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.
1279
1280            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
1281                to the called method. Note that this *should* include a ``self`` argument.
1282
1283            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
1284                to the called method
1285
1286            type_expr (Optional[Any]): an optional type annotation representing the
1287                Python type the output of this node will have.
1288
1289        Returns:
1290
1291            The newly created and inserted ``call_method`` node.
1292
1293        .. note::
1294            The same insertion point and type expression rules apply for this method
1295            as :meth:`Graph.create_node`.
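
        Example (a minimal sketch; it assumes ``graph`` is this ``Graph`` and
        ``x`` is an existing ``Node`` representing a ``Tensor``)::

            # Equivalent to calling `x.relu()`; `x` itself is passed as the
            # `self` argument via `args`
            relu_out = graph.call_method('relu', args=(x,))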
1296        """
1297        return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
1298
1299    @compatibility(is_backward_compatible=True)
1300    def call_function(self,
1301                      the_function: Callable[..., Any],
1302                      args: Optional[Tuple['Argument', ...]] = None,
1303                      kwargs: Optional[Dict[str, 'Argument']] = None,
1304                      type_expr: Optional[Any] = None) -> Node:
1305        """
1306        Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
1307        represents a call to a Python callable, specified by ``the_function``.
1308
1309        Args:
1310
1311            the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
1312                operator, Python function, or member of the ``builtins`` or ``operator``
1313                namespaces.
1314
1315            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
1316                to the called function.
1317
1318            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
1319                to the called function
1320
1321            type_expr (Optional[Any]): an optional type annotation representing the
1322                Python type the output of this node will have.
1323
1324        Returns:
1325
1326            The newly created and inserted ``call_function`` node.
1327
1328        .. note::
1329            The same insertion point and type expression rules apply for this method
1330            as :meth:`Graph.create_node`.
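
        Example (a minimal sketch; it assumes ``graph`` is this ``Graph`` and
        ``x`` and ``y`` are existing ``Node`` objects)::

            # Insert a call to torch.add(x, y)
            add_out = graph.call_function(torch.add, args=(x, y))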
1331        """
1332        return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
1333
1334    @compatibility(is_backward_compatible=True)
1335    def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
1336        """
1337        Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
1338        the graph of node to the graph of self. Example::
1339
1340            # Copying all the nodes in `g` into `new_graph`
1341            g : torch.fx.Graph = ...
1342            new_graph = torch.fx.Graph()
1343            value_remap = {}
1344            for node in g.nodes:
1345                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
1346
1347        Args:
1348
1349            node (Node): The node to copy into ``self``.
1350
1351            arg_transform (Callable[[Node], Argument]): A function that transforms
1352                ``Node`` arguments in node's ``args`` and ``kwargs`` into the
1353                equivalent argument in ``self``. In the simplest case, this should
1354                retrieve a value out of a table mapping Nodes in the original
1355                graph to ``self``.
1356        """
1357        args = map_arg(node.args, arg_transform)
1358        kwargs = map_arg(node.kwargs, arg_transform)
1359        assert isinstance(args, tuple)
1360        assert isinstance(kwargs, dict)
1361        result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type)
1362        result_node.meta = copy.copy(node.meta)
1363        return result_node
1364
1365    @compatibility(is_backward_compatible=True)
1366    def output(self, result: 'Argument', type_expr: Optional[Any] = None):
1367        """
1368        Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
1369        a ``return`` statement in Python code. ``result`` is the value that should
1370        be returned.
1371
1372        Args:
1373
1374            result (Argument): The value to be returned.
1375
1376            type_expr (Optional[Any]): an optional type annotation representing the
1377                Python type the output of this node will have.
1378
1379        .. note::
1380
1381            The same insertion point and type expression rules apply for this method
1382            as ``Graph.create_node``.
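
        Example (a minimal sketch; it assumes ``graph`` is this ``Graph`` and
        ``result_node`` is an existing ``Node``)::

            # The generated forward() will return the value produced by `result_node`
            graph.output(result_node)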
1383        """
1384        return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
1385
1386    def _target_to_str(self, target : Target) -> str:
1387        if callable(target):
1388            op = target.__name__
1389        else:
1390            assert isinstance(target, str)
1391            op = target
1392            if _is_magic(op):
1393                op = op[2:-2]  # strip the '__' prefix/suffix from magic method names, e.g. '__add__' -> 'add'
1394        op = _snake_case(op)  # e.g. 'pascalCase' -> 'pascal_case'
1395        return op
1396
1397    @compatibility(is_backward_compatible=True)
1398    def python_code(
1399        self, root_module: str, *,
1400        verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False
1401    ) -> PythonCode:
1402        """
1403        Turn this ``Graph`` into valid Python code.
1404
1405        Args:
1406
1407            root_module (str): The name of the root module on which to look up
1408                qualified name targets. This is usually 'self'.
1409
1410        Returns:
1411
1412            A PythonCode object, consisting of two fields:
1413                src: the Python source code representing the object
1414                globals: a dictionary of global names in `src` -> the objects that they reference.
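
        Example (a minimal sketch; it assumes ``gm`` is an existing
        ``GraphModule``)::

            python_code = gm.graph.python_code(root_module='self')
            print(python_code.src)      # the generated Python source
            print(python_code.globals)  # objects referenced by that source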
1415        """
1416        # NOTE: [Graph Namespaces]
1417        #
1418        # There are two types of symbols in generated Python source code:
1419        # locals and globals.
1420        #   Locals are locally defined by the output of a node in the Graph.
1421        #   Globals are references to external objects, like functions or types.
1422        #
1423        # When generating Python code, we need to make sure to name things
1424        # appropriately. In particular:
1425        # - All names should be unique, to avoid weird shadowing bugs.
1426        # - These names need to be consistent, e.g. an object should always be
1427        #   referenced by the same name.
1428        #
1429        # To do this, we create a new namespace just for this source. All names
1430        # that get printed must come from this namespace.
1431        #
1432        # Why can't we re-use node.name? Because it was generated within the
1433        # namespace `self._graph_namespace`. In order to provide uniqueness
1434        # over both locals (node.name) *and* globals, we create a completely
1435        # new namespace to put all identifiers in.
1436        namespace = _Namespace()
1437
1438        # Override Node's repr to generate a valid name within our namespace.
1439        # Since repr() is designed to produce a valid Python expression, it
1440        # makes sense to re-use it. This way, it's easy to print something like
1441        # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is
1442        # implemented cooperatively to allow this.
1443        def node_repr(n: Node):
1444            return namespace.create_name(n.name, n)
1445
1446        @contextmanager
1447        def override_node_repr(graph: Graph):
1448            orig_repr_fns = {}
1449            for node in graph.nodes:
1450                orig_repr_fns[node] = node._repr_fn
1451                node._repr_fn = node_repr
1452            try:
1453                yield None
1454            finally:
1455                # restore the original repr functions
1456                for node in graph.nodes:
1457                    node._repr_fn = orig_repr_fns[node]
1458
1459        with override_node_repr(self):
1460            return self._python_code(
1461                root_module, namespace,
1462                verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored
1463            )
1464
1465    def _python_code(
1466        self, root_module: str, namespace: _Namespace, *,
1467        verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False,
1468    ) -> PythonCode:
1469        return self._codegen._gen_python_code(
1470            self.nodes, root_module, namespace,
1471            verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored
1472        )
1473
1474
1475    def __str__(self) -> str:
1476        """
1477        Return a human-readable (not machine-readable) string representation
1478        of this Graph
1479        """
1480        placeholder_names : List[str] = []
1481        # This is a one-element array just so ``format_node`` can modify the closed
1482        # over value
1483        maybe_return_typename : List[str] = ['']
1484
1485        node_strs = [node.format_node(placeholder_names) for node in self.nodes]
1486        param_str = ', '.join(placeholder_names)
1487        s = f'graph({param_str}){maybe_return_typename[0]}:'
1488        for node_str in node_strs:
1489            if node_str:
1490                s += '\n    ' + node_str
1491        return s
1492
1493    @compatibility(is_backward_compatible=True)
1494    def print_tabular(self):
1495        """
1496        Prints the intermediate representation of the graph in tabular
1497        format. Note that this API requires the ``tabulate`` module to be
1498        installed.
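
        Example (a minimal sketch; it assumes ``gm`` is an existing
        ``GraphModule`` and that ``tabulate`` is installed)::

            gm.graph.print_tabular()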
1499        """
1500        try:
1501            from tabulate import tabulate
1502        except ImportError:
1503            print("`print_tabular` relies on the library `tabulate`, "
1504                  "which could not be found on this machine. Run `pip "
1505                  "install tabulate` to install the library.")
1506            raise
1507
1508        node_specs = [[n.op, n.name, n.target, n.args, n.kwargs]
1509                      for n in self.nodes]
1510        print(tabulate(node_specs,
1511              headers=['opcode', 'name', 'target', 'args', 'kwargs']))
1512
1513    @compatibility(is_backward_compatible=True)
1514    def lint(self):
1515        """
1516        Runs various checks on this Graph to make sure it is well-formed. In
1517        particular:
1518        - Checks Nodes have correct ownership (owned by this graph)
1519        - Checks Nodes appear in topological order
1520        - If this Graph has an owning GraphModule, checks that targets
1521          exist in that GraphModule
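
        Example (a minimal sketch; it assumes ``gm`` is an existing
        ``GraphModule``)::

            # Raises if the graph is malformed, e.g. if its nodes are not in
            # topological order
            gm.graph.lint()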
1522        """
1523
1524        # Check topo order
1525        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
1526            context_str = f' of Node \'{n}\' ' if n else ' '
1527            if arg.graph is not self:
1528                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
1529                                   f'but was used as an argument! If you are copying nodes from another graph, make '
1530                                   f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}')
1531            if arg not in seen_values:
1532                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
1533                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')
1534
1535        seen_names : Set[str] = set()
1536        seen_values : Set[Node] = set()
1537        for node in self.nodes:
1538            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
1539                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
1540            if node.graph is not self:
1541                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
1542            if node not in self._find_nodes_lookup_table:
1543                raise RuntimeError(f"Node '{node}' is not added to the side table")
1544            map_arg(node.args, lambda arg: check_arg(arg, node))
1545            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
1546            seen_values.add(node)
1547
1548            if node.name in seen_names:
1549                raise RuntimeError(f'Node redefined name {node.name}!')
1550            seen_names.add(node.name)
1551
1552        # Check targets are legit
1553        if self.owning_module:
1554            num_warnings = 0
1555            MAX_WARNINGS = 5
1556            for node in self.nodes:
1557                if node.op == 'call_function':
1558                    if not callable(node.target):
1559                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
1560                                         'a Callable is expected')
1561                else:
1562                    if not isinstance(node.target, str):
1563                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
1564                                         'a str is expected')
1565                if node.op in ['get_attr', 'call_module']:
1566                    target_atoms = node.target.split('.')
1567                    m_itr = self.owning_module
1568                    for i, atom in enumerate(target_atoms):
1569                        new_m_itr = getattr(m_itr, atom, None)
1570                        seen_qualname = '.'.join(target_atoms[:i])
1571                        if new_m_itr is None:
1572                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
1573                                               f'{atom} of {seen_qualname}')
1574                        if (node.op == "call_module"
1575                                and not isinstance(new_m_itr, torch.nn.Module)):
1576                            raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
1577                                               'not reference an nn.Module')
1578                        elif (node.op == "get_attr"
1579                              and not isinstance(new_m_itr, torch.nn.Module)
1580                              and not isinstance(new_m_itr, torch.nn.Parameter)
1581                              and atom not in m_itr._buffers):
1582                            if num_warnings < MAX_WARNINGS:
1583                                # Don't emit this warning too frequently,
1584                                # for very large graphs this can become very expensive
1585                                # from a performance perspective.
1586                                warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
1587                                              'not reference an nn.Module, nn.Parameter, or buffer, which is '
1588                                              'what \'get_attr\' Nodes typically target')
1589                            num_warnings += 1
1590                        else:
1591                            m_itr = new_m_itr
1592            if num_warnings > MAX_WARNINGS:
1593                warnings.warn(
1594                    f'Additional {num_warnings - MAX_WARNINGS} warnings '
1595                    'suppressed about get_attr references'
1596                )
1597
1598    @compatibility(is_backward_compatible=True)
1599    def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None):
1600        """
1601        Remove all dead code from the graph, based on each node's number of
1602        users, and whether the nodes have any side effects. The graph must be
1603        topologically sorted before calling.
1604
1605        Args:
1606            is_impure_node (Optional[Callable[[Node], bool]]): A function that returns
1607                whether a node is impure. If this is None, then the default behavior is to
1608                use Node.is_impure.
1609
1610        Returns:
1611            bool: Whether the graph was changed as a result of the pass.
1612
1613        Example:
1614
1615        Before dead code is eliminated, `a` from `a = x + 1` below has no users
1616        and thus can be eliminated from the graph without having an effect.
1617
1618        .. code-block:: python
1619
1620            def forward(self, x):
1621                a = x + 1
1622                return x + self.attr_1
1623
1624        After dead code is eliminated, `a = x + 1` has been removed, and the rest
1625        of `forward` remains.
1626
1627        .. code-block:: python
1628
1629            def forward(self, x):
1630                return x + self.attr_1
1631
1632        .. warning::
1633
1634            Dead code elimination has some heuristics to avoid removing
1635            side-effectful nodes (see Node.is_impure) but in general coverage
1636            is very bad, so you should assume that this method is not sound
1637            to call unless you know that your FX graph consists entirely
1638            of functional operations or you supply your own custom
1639            function for detecting side-effectful nodes.
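
        Example call (a minimal sketch; it assumes ``gm`` is an existing
        ``GraphModule`` consisting of purely functional operations)::

            changed = gm.graph.eliminate_dead_code()
            if changed:
                gm.recompile()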
1640        """
1641        # Lint the graph first to make sure it is topologically sorted, otherwise
1642        # DCE below will not behave as expected.
1643        self.lint()
1644
1645        def has_side_effect(node):
1646            if is_impure_node is not None:
1647                return is_impure_node(node)
1648            return node.is_impure()
1649
1650        # Reverse iterate so that when we remove a node, any nodes used as an
1651        # input to that node have an updated user count that no longer reflects
1652        # the removed node.
1653        changed = False
1654        for node in reversed(self.nodes):
1655            if not has_side_effect(node) and len(node.users) == 0:
1656                self.erase_node(node)
1657                changed = True
1658
1659        return changed
1660
1661    @compatibility(is_backward_compatible=False)
1662    def set_codegen(self, codegen: CodeGen):
1663        self._codegen = codegen  # controls how this Graph is rendered into Python source
1664
1665    @compatibility(is_backward_compatible=False)
1666    def on_generate_code(
1667        self,
1668        make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]
1669    ):
1670        """Register a transformer function when python code is generated
1671
1672        Args:
1673            make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
1674                a function that returns a code transformer to be registered.
1675                This function is called by `on_generate_code` to obtain the
1676                code transformer.
1677
1678                This function is also given as its input the currently
1679                registered code transformer (or None if nothing is registered),
1680                in case it is not desirable to overwrite it. This is useful for
1681                chaining code transformers together.
1682
1683        Returns:
1684            a context manager that, when used in a `with` statement, automatically
1685            restores the previously registered code transformer.
1686
1687        Example:
1688
1689        .. code-block:: python
1690
1691
1692            gm: fx.GraphModule = ...
1693
1694            # This is a code transformer we want to register. This code
1695            # transformer prepends a pdb import and trace statement at the very
1696            # beginning of the generated torch.fx code to allow for manual
1697            # debugging with the PDB library.
1698            def insert_pdb(body):
1699                return ["import pdb; pdb.set_trace()\\n", *body]
1700
1701            # Registers `insert_pdb`, and overwrites the current registered
1702            # code transformer (given by `_` to the lambda):
1703            gm.graph.on_generate_code(
1704                lambda _: insert_pdb
1705            )
1706
1707            # Or alternatively, registers a code transformer which first
1708            # runs `body` through existing registered transformer, then
1709            # through `insert_pdb`:
1710            gm.graph.on_generate_code(
1711                lambda current_trans: (
1712                    lambda body: insert_pdb(
1713                        current_trans(body) if current_trans
1714                        else body
1715                    )
1716                )
1717            )
1718
1719            gm.recompile()
1720            gm(*inputs)  # drops into pdb
1721
1722
1723        This function can also be used as a context manager, with the benefit of
1724        automatically restoring the previously registered code transformer:
1725
1726        .. code-block:: python
1727
1728            # ... continue from previous example
1729
1730            with gm.graph.on_generate_code(lambda _: insert_pdb):
1731                # do more stuff with `gm`...
1732                gm.recompile()
1733                gm(*inputs)  # drops into pdb
1734
1735            # now previous code transformer is restored (but `gm`'s code with pdb
1736            # remains - that means you can run `gm` with pdb here too, until you
1737            # run the next `recompile()`).
1738        """
1739        on_gen_code_old = self._codegen._body_transformer
1740        self._codegen._body_transformer = make_transformer(on_gen_code_old)
1741
1742        @contextlib.contextmanager
1743        def on_generate_code_context_manager():
1744            try:
1745                yield
1746            finally:
1747                self._codegen._body_transformer = on_gen_code_old
1748
1749        return on_generate_code_context_manager()
1750
1751
1752reflectable_magic_methods = {
1753    'add': '{} + {}',
1754    'sub': '{} - {}',
1755    'mul': '{} * {}',
1756    'floordiv': '{} // {}',
1757    'truediv': '{} / {}',
1758    'div': '{} / {}',
1759    'mod': '{} % {}',
1760    'pow': '{} ** {}',
1761    'lshift': '{} << {}',
1762    'rshift': '{} >> {}',
1763    'and_': '{} & {}',
1764    'or_': '{} | {}',
1765    'xor': '{} ^ {}',
1766    'getitem': '{}[{}]',
1767    'matmul': '{} @ {}',
1768}
1769
1770magic_methods = dict({
1771    'eq': '{} == {}',
1772    'ne': '{} != {}',
1773    'lt': '{} < {}',
1774    'gt': '{} > {}',
1775    'le': '{} <= {}',
1776    'ge': '{} >= {}',
1777    'pos': '+{}',
1778    'neg': '-{}',
1779    'invert': '~{}'}, **reflectable_magic_methods)
1780
1781inplace_methods = {
1782    'iadd': '{} += {}',
1783    'iand': '{} &= {}',
1784    'ifloordiv': '{} //= {}',
1785    'ilshift': '{} <<= {}',
1786    'imod': '{} %= {}',
1787    'imul': '{} *= {}',
1788    'imatmul': '{} @= {}',
1789    'ior': '{} |= {}',
1790    'ipow': '{} **= {}',
1791    'irshift': '{} >>= {}',
1792    'isub': '{} -= {}',
1793    'itruediv': '{} /= {}',
1794    'ixor': '{} ^= {}',
1795    'setitem': '{}[{}] = {}',
1796}
1797