# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import io
import json
import logging
import os
import re
import sys
import types
import warnings
import weakref
import zipfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.fx
import torch.utils._pytree as pytree

from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import log_export_usage
from torch.export._tree_utils import reorder_kwargs
from torch.export.graph_signature import (
    ArgumentSpec,
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from .wrappers import _wrap_submodules

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """
    allow_rnn: bool = True


# We only want to print this warning once, to avoid flooding logs in workflows where
# capture_pre_autograd_graph is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
    from torch._inductor import config

    log.warning("+============================+")
    log.warning("|     !!!   WARNING   !!!    |")
    log.warning("+============================+")
    log.warning("capture_pre_autograd_graph() is deprecated and does not provide any functional guarantees going forward.")
    log.warning("Please switch to torch.export.export_for_training instead.")
    if config.is_fbcode():
        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fall back to torch.export.export_for_training.")  # noqa: B950


@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
    f: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None,
) -> torch.nn.Module:
    """
    A helper function intended to trace a module before any pre-autograd
    decomposition is run. The produced module is "non-functional" and
    composed of aten operators. This API will eventually be removed in favor of the
    more general torch.export APIs.

    Args:
      f: the nn.Module to be traced.

      args: example positional inputs.

      kwargs: optional example keyword inputs.

      dynamic_shapes: Should be either:
         1) a dict mapping argument names of ``f`` to their dynamic shape specifications, or
         2) a tuple specifying a dynamic shape specification for each input, in the original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order
         defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where static
         dimension indices may be omitted, but when they are included they should be
         mapped to None; or (2) a tuple / list of :func:`Dim` types or None, where the
         :func:`Dim` types correspond to dynamic dimensions and static dimensions are
         denoted by None. Arguments that are dicts or tuples / lists of tensors are
         specified recursively using mappings or sequences of contained specifications.

    Returns:
        An nn.Module containing the traced method.

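    Example (a minimal, illustrative sketch; the module ``M`` and the example
    inputs below are hypothetical and not part of this API)::

        import torch

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin() + x.cos()

        example_args = (torch.randn(4, 8),)
        traced = capture_pre_autograd_graph(M(), example_args)
        out = traced(*example_args)

        # Optionally mark a dimension as dynamic via torch.export.Dim.
        batch = torch.export.Dim("batch")
        traced_dynamic = capture_pre_autograd_graph(
            M(), example_args, dynamic_shapes={"x": {0: batch}}
        )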
    """
    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
    from torch._export.non_strict_utils import make_constraints
    from torch._subclasses.functional_tensor import FunctionalTensor
    from torch.export._unlift import _create_stateful_graph_module
    from torch.export.dynamic_shapes import _combine_args

    capture_pre_autograd_graph_warning()

    if sys.platform == "win32":
        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")

    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

    if kwargs is None:
        kwargs = {}

    if capture_pre_autograd_graph_using_training_ir():
        @lru_cache
        def print_export_warning():
            log.warning("Using torch.export.export_for_training(...,strict=True)")
        print_export_warning()
        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
    else:
        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

        # Do not decompose dropout for exported models, because in eval mode the dropout
        # op disappears from the graph, which makes it difficult to switch to train mode.
        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
        decomp_table = {
            op: op.decompose
            for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
            if op != torch.ops.aten.dropout.default
        }
        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
            m = torch._dynamo.export(
                f,
                dynamic_shapes=dynamic_shapes,
                assume_static_by_default=True,
                tracing_mode="symbolic",
                decomposition_table=decomp_table,
                pre_dispatch=True,
                aten_graph=True,
                _log_export_usage=False,
            )(
                *args,
                **kwargs,
            )[0]

            _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

            # Record the value ranges of symbols created during tracing (named like
            # ``i0``/``f0``) as inline constraints on the graph module's metadata.
            m.meta["inline_constraints"] = {
                k: v
                for k, v in fake_mode.shape_env.var_to_range.items()
                if re.match(r"^[if]\d+$", str(k))
            }

            if isinstance(f, torch.nn.Module):
                from torch.export._trace import _restore_state_dict
                _restore_state_dict(f, m)

            flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
            combined_args = _combine_args(f, args, kwargs)
            range_constraints = make_constraints(
                fake_mode,
                m,
                combined_args,
                dynamic_shapes,
                0,
            )

            module = _create_stateful_graph_module(
                m,
                range_constraints=range_constraints,
            )

    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to implement custom behavior as follows:

            def _my_train(self, mode: bool = True):
                ...

            def _my_eval(self):
                ...

            model.train = types.MethodType(_my_train, model)
            model.eval = types.MethodType(_my_eval, model)
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]

    # Remove Proxy objects from buffers because they cannot be deep-copied or pickled.
    if hasattr(module, "_buffers"):
        torch._export.utils.remove_proxy_from_state_dict(
            module._buffers, in_place=True
        )
    return module


def aot_compile(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
    same_signature: bool = True,
) -> str:
    """
    Note: this function is not yet stable.

    Traces either an nn.Module's forward function or a plain callable containing PyTorch
    operations, generates executable C++ code from the program, and returns the path to
    the generated shared library.

    Args:
        f: the `nn.Module` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should be either:
            1) a dict mapping argument names of ``f`` to their dynamic shape specifications, or
            2) a tuple specifying a dynamic shape specification for each input, in the original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order
            defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where static
            dimension indices may be omitted, but when they are included they should be
            mapped to None; or (2) a tuple / list of :func:`Dim` types or None, where the
            :func:`Dim` types correspond to dynamic dimensions and static dimensions are
            denoted by None. Arguments that are dicts or tuples / lists of tensors are
            specified recursively using mappings or sequences of contained specifications.

        options: A dictionary of options to control Inductor compilation.

        disable_constraint_solver: Whether to disable the dimension constraint solver.

    Returns:
        Path to the generated shared library.
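
    Example (a minimal sketch; ``M`` and the example inputs are illustrative, and
    compiling requires a working Inductor/C++ toolchain)::

        import torch

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        so_path = torch._export.aot_compile(M(), (torch.randn(8, 16),))
        compiled = torch._export.aot_load(so_path, device="cpu")
        out = compiled(torch.randn(8, 16))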
    """
    from torch.export._trace import _export_to_torch_ir
    from torch._inductor.decomposition import select_decomp_table
    from torch._inductor import config

    if config.is_predispatch:
        gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
    else:
        # We want to export to Torch IR here to utilize the pre_grad passes in
        # Inductor, which run on Torch IR.
        gm = _export_to_torch_ir(
            f,
            args,
            kwargs,
            dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
            same_signature=same_signature,
            # Disable this flag; we can instead rely on the
            # dynamo_flat_name_to_original_fqn mapping coming from Dynamo.
            restore_fqn=False,
        )

    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]

    return so_path


def aot_load(so_path: str, device: str) -> Callable:
    """
    Loads a shared library generated by aot_compile and returns a callable.

    Args:
        so_path: Path to the shared library.

        device: The device on which to run the loaded model, e.g. "cpu" or "cuda".

    Returns:
        A callable that runs the compiled model.
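
    Example (illustrative; assumes ``model.so`` was previously produced by
    :func:`aot_compile` for a CUDA device)::

        compiled = torch._export.aot_load("model.so", device="cuda")
        out = compiled(torch.randn(8, 16, device="cuda"))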
    """
    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError(f"Unsupported device {device}")

    def optimized(*args, **kwargs):
        # The call spec records how inputs and outputs were flattened at compile time.
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        # The AOTI runner only accepts tensors; drop any non-tensor leaves.
        flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized