# xref: /aosp_15_r20/external/pytorch/torch/fx/proxy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import collections
import logging

from dataclasses import is_dataclass, fields


from .graph import magic_methods, reflectable_magic_methods, Graph
from torch.utils._traceback import CapturedTraceback
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback

__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
           'ScopeContextManager']


log = logging.getLogger(__name__)


@compatibility(is_backward_compatible=False)
class Scope:
    """ Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type

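# Illustrative ``Scope`` values (hypothetical, mirroring the docstring above):
#
#     Scope("sub", Sub)   # a Node created inside ``self.sub``'s forward
#     Scope("", None)     # a Node created in the root module's forward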

@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """ A context manager to track the Scope of Node during symbolic tracing.
    When entering a forward function of a Module, we'll update the scope information of
    the current module, and when we exit, we'll restore the previous scope information.
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # Keep a copy of prev scope to restore on exit
        self._prev_scope = copy.copy(scope)
        # Update scope to current scope
        scope.module_path = current_scope.module_path
        scope.module_type = current_scope.module_type
        # Save a reference so we can restore it
        self._scope = scope

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        self._scope.module_path = self._prev_scope.module_path
        self._scope.module_type = self._prev_scope.module_type
        return

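# Sketch of how a tracer might use ``ScopeContextManager`` when entering a
# submodule's ``forward`` (``tracer`` and ``Sub`` are hypothetical here):
#
#     with ScopeContextManager(tracer.scope, Scope("sub", Sub)):
#         ...  # Nodes created here are attributed to ("sub", Sub)
#     # On exit, ``tracer.scope`` is restored to its previous path/type.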

_COPY_META_FIELDS = [
    "nn_module_stack",
    "torch_fn",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "ac_graph_id",
    "from_node",
    "quantization_tag",  # TODO deprecated
    "_numeric_debug_handle",  # TODO deprecated
    "custom",
    "partitioner_tag"
]


@compatibility(is_backward_compatible=True)
class TracerBase:
    graph: Graph
    record_stack_traces : bool = False
    # Feature flag for mutable schema checking
    # Enabled by default in 1.12
    check_mutable_operations : bool = False
    # Feature flag for assert tracing
    trace_asserts : bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes : bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # The scope (module path and module type) of the module currently
    # being traced; used to attribute created Nodes
    scope : Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.
        """

        if kind == 'call_function' and self.check_mutable_operations:
            check_for_mutable_operation(target, args, kwargs)

        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
        # node.meta['nn_module_stack']
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        # Optionally set stack trace on the created Node for debugging purposes
        if fx_traceback.has_preserved_node_meta():
            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

            stack_trace = current_meta.get("stack_trace")
            if stack_trace:
                node.stack_trace = stack_trace
            # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
            # If other meta fields are needed, they can be added here
            for field in _COPY_META_FIELDS:
                if field in current_meta:
                    node.meta[field] = copy.copy(current_meta[field])

            # Here we decrement to account for the sequence_nr having
            # just been incremented while tracing this lowered aten op.
            new_seq_nr = torch.autograd._get_sequence_nr() - 1
            # The sequence_nr increments every time a new autograd Node
            # is created. During the FWD pass we store the sequence_nr
            # corresponding to the last autograd Node created on this fx
            # node's meta.  A single aten op can create multiple autograd
            # nodes as is the case with in-place foreach ops. During the
            # BWD pass we retrieve the sequence_nr stored on the current
            # executing autograd Node. See NOTE [ Sequence Number ].
            if current_meta.get("in_grad_fn", 0) > 0:
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
            node.meta["seq_nr"] = new_seq_nr

        elif self.module_stack:
            node.meta['nn_module_stack'] = copy.copy(self.module_stack)

        log.debug("create_node %s", node)
        return node

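    # As the docstring above suggests, ``create_node`` can be overridden to
    # vet operations. A minimal sketch (``NoInplaceTracer`` and its trailing-
    # underscore heuristic are hypothetical, not part of torch.fx):
    #
    #     class NoInplaceTracer(torch.fx.Tracer):
    #         def create_node(self, kind, target, args, kwargs,
    #                         name=None, type_expr=None):
    #             if kind == 'call_method' and isinstance(target, str) and target.endswith('_'):
    #                 raise TraceError(f"in-place method '{target}' is not allowed")
    #             return super().create_node(kind, target, args, kwargs, name, type_expr)
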
    @compatibility(is_backward_compatible=True)
    def proxy(self, node: Node) -> 'Proxy':
        return Proxy(node, self)

    @compatibility(is_backward_compatible=True)
    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Optional[Callable[[Node], 'Proxy']] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.
        '''

        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        if self.record_stack_traces and not proxy.node.stack_trace:
            proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())

        return proxy

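    # Illustrative call (a sketch; assumes ``tracer`` is an instance with an
    # initialized ``graph``, e.g. a ``GraphAppendingTracer``):
    #
    #     p = tracer.create_proxy('call_function', torch.relu, (x_proxy,), {})
    #     # ``p.node`` is the freshly appended call_function Node; ``p`` can
    #     # keep flowing through user code to record further operations.
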
    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.
        """
        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame executing
        # the user code during tracing.
        frame = inspect.currentframe()

        pt_files = ['torch/fx/proxy.py',
                    'torch/fx/_symbolic_trace.py',
                    'torch/fx/experimental/proxy_tensor.py',
                    'torch/_ops.py',
                    'torch/_tensor.py',
                    'torch/utils/_python_dispatch.py',
                    'torch/_prims_common/wrappers.py',
                    'torch/_refs/__init__.py',
                    'torch/_refs/nn/functional/__init__.py',
                    'torch/utils/_stats.py',
                    ]
        while frame:
            frame = frame.f_back
            if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
                break

        if not frame:
            return None

        return frame

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.
        """
        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
            return a.__fx_create_arg__(self)
        # aggregates
        elif isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors don't seem to like getting a generator
            # expression as an argument to their constructor, so build this
            # intermediate tuple and unpack it into the NamedTuple constructor
            args = tuple(self.create_arg(elem) for elem in a)
            return type(a)(*args)  # type: ignore[arg-type]
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)

                def no_node(arg):
                    if isinstance(arg, Node):
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                           f"Node. Got key: {k}")
                map_aggregate(k, no_node)

                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node

        if is_dataclass(a):
            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
            return self.create_node("call_function", a.__class__, (), kwargs)

        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
            return a
        raise NotImplementedError(f"argument of type: {type(a)}")

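    # Custom types can opt into lowering via ``__fx_create_arg__``, which the
    # first branch above checks. A hedged sketch (``Box`` is a hypothetical
    # user type, not part of torch.fx):
    #
    #     class Box:
    #         def __init__(self, payload):
    #             self.payload = payload
    #
    #         def __fx_create_arg__(self, tracer: TracerBase):
    #             # Lower to something the IR can store.
    #             return tracer.create_arg(self.payload)
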
    @compatibility(is_backward_compatible=True)
    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    @compatibility(is_backward_compatible=True)
    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.
        """
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as a *args or **kwargs function argument. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: 'Proxy') -> Any:
        """Called when the keys() method is called on a proxy object.
        This is what happens when ** is used on a proxy. This should return an
        iterator if ** is supposed to work in your custom tracer.
        """
        return Attribute(obj, 'keys')()


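# The three hooks above are what a subclass overrides to make more control
# flow traceable. A hedged sketch of a tracer that unrolls iteration to a
# fixed length (``FixedLenTracer`` and its length-3 policy are hypothetical):
#
#     class FixedLenTracer(torch.fx.Tracer):
#         def iter(self, obj):
#             # Pretend every iterated proxy has exactly 3 elements.
#             return iter([obj[i] for i in range(3)])
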
# used in Proxy object when just appending to the graph while not tracing.
@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
    def __init__(self, graph: Graph):
        super().__init__()
        self.graph = graph
        self.scope = Scope("", None)
        self.module_stack = collections.OrderedDict()
        self.node_name_to_scope = {}

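# Sketch: appending to an existing Graph through proxies, outside of a trace
# (variable names are illustrative):
#
#     g = torch.fx.Graph()
#     tracer = GraphAppendingTracer(g)
#     x = tracer.create_proxy('placeholder', 'x', (), {})
#     y = x + 1   # records call_function(operator.add) into ``g``
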
@compatibility(is_backward_compatible=False)
def assert_fn(x):
    assert x

@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
    pass

@compatibility(is_backward_compatible=True)
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap a raw ``Node`` in
    your own ``Proxy`` object so that you can use the overloaded
    operators to add additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
    use ``fx.wrap`` on it (a sketch appears after this class).
    2. If the control flow is static (i.e. the loop trip count is
    based on some hyperparameter), the code can be kept in its original
    position and refactored into something like::

        for i in range(self.some_hyperparameter):
            indexed_item = proxied_value[i]

    For a more detailed description of the Proxy internals, check out
    the "Proxy" section in `torch/fx/README.md`
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __getstate__(self) -> Dict:
        return self.__dict__

    def __deepcopy__(self, memo) -> 'Proxy':
        # We have to explicitly override this method, because otherwise deepcopy
        # will go to __getattr__(self, "__deepcopy__") and return an
        # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
        new_dict = {}
        for k, v in self.__dict__.items():
            try:
                new_obj = copy.deepcopy(v, memo)
            except Exception:
                log.warning(
                    "Shallow copy %s of Proxy because it cannot be deepcopied. "
                    "Proxy is created for node %s", k, self.node.name)
                new_obj = copy.copy(v)
            new_dict[k] = new_obj
        assert "node" in new_dict
        assert "tracer" in new_dict
        new_proxy = Proxy(new_dict["node"], new_dict["tracer"])
        for k, v in new_dict.items():
            new_proxy.__dict__[k] = v
        return new_proxy

    def __setstate__(self, d):
        # This is called when being unpickled/loaded.
        self.__dict__ = d

    def __call__(self, *args, **kwargs) -> 'Proxy':
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

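    # ``__iter__`` peeks at the calling frame's bytecode: when the proxy is
    # consumed by a fixed-size tuple unpack (UNPACK_SEQUENCE), the number of
    # targets is statically known, so we can emit one getitem per element.
    # Any other form of iteration is deferred to ``tracer.iter``.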
    def __iter__(self) -> Iterator['Proxy']:
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        inst_list = list(dis.get_instructions(calling_frame.f_code))
        if sys.version_info >= (3, 11):
            from bisect import bisect_left
            inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset)
        else:
            inst_idx = calling_frame.f_lasti // 2
        inst = inst_list[inst_idx]
        if inst.opname == 'UNPACK_SEQUENCE':
            return (self[i] for i in range(inst.argval))  # type: ignore[index]

        return self.tracer.iter(self)

    def __abs__(self):
        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})

    def __bool__(self) -> bool:
        if self.tracer.trace_asserts:
            # check if this boolean is used in an assertion, bytecode pattern for assertions
            # is pretty stable for Python 3.7--3.9
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts = list(dis.get_instructions(calling_frame.f_code))
            if sys.version_info >= (3, 11):
                from bisect import bisect_left
                cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
            else:
                cur = calling_frame.f_lasti // 2
            inst = insts[cur]

            if inst.opname == 'POP_JUMP_IF_TRUE':
                first = insts[cur + 1]
                assert inst.arg is not None
                last = insts[inst.arg // 2 - 1]
                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                      or first.opname == 'LOAD_ASSERTION_ERROR')
                if starts_with_assert and last.opname == 'RAISE_VARARGS':
                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                    return True

        return self.tracer.to_bool(self)

    @compatibility(is_backward_compatible=True)
    def keys(self):
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        tracers : Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None
        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                               f'trying to trace operations {orig_method}')
        tracer = next(iter(tracers.keys()))

        if isinstance(orig_method, torch._C.ScriptMethod):
            args = (orig_method.owner,) + args
            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            if isinstance(orig_method, torch._ops.HigherOrderOperator):
                # TODO: Define how to symbolically trace HigherOrderOperators
                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
            return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                       name=tracer.graph._target_to_str(orig_method.__name__))


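# The ``Proxy`` docstring above names ``fx.wrap`` as the escape hatch for
# untraceable logic. A hedged sketch (``data_dependent_sum`` is hypothetical):
#
#     @torch.fx.wrap
#     def data_dependent_sum(xs):
#         total = 0
#         for x in xs:   # loop is opaque to the tracer
#             total = total + x
#         return total
#
#     # Inside a traced ``forward``, a call to ``data_dependent_sum`` is
#     # recorded as a single call_function Node instead of being traced into.
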
@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
    @compatibility(is_backward_compatible=True)
    def __init__(self, root: Proxy, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getattr call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)


@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
    """
    A special proxy which lets "shape", "size", "dim", and a few other
    attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw an exception
    during tracing.
    """
    def __init__(self, tracer: TracerBase, node: Node, name, param):
        super().__init__(node, tracer)
        assert isinstance(param, torch.nn.Parameter)
        self.param = param
        self.name = name

    def __repr__(self) -> str:
        return f'ParameterProxy({self.name})'

    @property
    def shape(self):
        return self.param.shape

    def size(self):
        return self.param.size()

    def dim(self):
        return self.param.dim()

    @property
    def ndim(self):
        return self.param.ndim

    def numel(self):
        return self.param.numel()

    def nelement(self):
        return self.param.nelement()


for method in magic_methods:
    def _scope(method):
        def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
            return tracer.create_proxy('call_function', target, args, kwargs)
        impl.__name__ = method
        as_magic = f'__{method.strip("_")}__'
        setattr(Proxy, as_magic, impl)
    _scope(method)

def _define_reflectable(orig_method_name):
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        return self.tracer.create_proxy('call_function', target, (rhs, self), {})
    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(Proxy, method_name, impl)

for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
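
# Net effect of the two loops above, sketched (names illustrative):
#
#     g = torch.fx.Graph()
#     p = Proxy(g.placeholder('x'))
#     _ = p + 1   # Proxy.__add__  -> call_function(operator.add, (p, 1))
#     _ = 1 + p   # Proxy.__radd__ -> call_function(operator.add, (1, p))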
610