xref: /aosp_15_r20/external/pytorch/torch/export/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import builtins
2import copy
3import dataclasses
4import inspect
5import io
6import os
7import sys
8import typing
9import warnings
10import zipfile
11from enum import auto, Enum
12from typing import (
13    Any,
14    Callable,
15    Dict,
16    Iterator,
17    List,
18    Optional,
19    Tuple,
20    Type,
21    TYPE_CHECKING,
22    Union,
23)
24
25import torch
26import torch.utils._pytree as pytree
27from torch.fx._compatibility import compatibility
28from torch.fx.passes.infra.pass_base import PassResult
29from torch.fx.passes.infra.pass_manager import PassManager
30from torch.utils._pytree import (
31    FlattenFunc,
32    FromDumpableContextFn,
33    ToDumpableContextFn,
34    UnflattenFunc,
35)
36
37
38if TYPE_CHECKING:
39    # Import the following modules during type checking to enable code intelligence features,
40    # Do not import unconditionally, as they import sympy and importing sympy is very slow
41    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
42
43
# Public API of the `torch.export` package; the names listed here are
# re-exported from the submodules imported below.
# NOTE(review): `ShapesCollection` is imported from .dynamic_shapes but is
# absent from this list — presumably intentional; confirm upstream.
__all__ = [
    "Constraint",
    "Dim",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "dims",
    "export",
    "export_for_training",
    "load",
    "register_dataclass",
    "save",
    "unflatten",
    "FlatArgsAdapter",
    "UnflattenedModule",
]
62
63
64from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
65from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
66from .graph_signature import ExportBackwardSignature, ExportGraphSignature
67from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
68
69
70PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
71
72
def export_for_training(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """
    Trace ``mod.forward`` ahead-of-time (AOT) with example inputs and return a
    traced graph representing only the Tensor computation, which can later be
    executed with different inputs or serialized.

    The traced graph (1) contains normalized operators from the ATen operator
    set (as well as any user-specified custom operators), (2) has eliminated
    all Python control flow and data structures (with certain exceptions), and
    (3) records the set of shape constraints needed to show that this
    normalization and control-flow elimination is sound for future inputs.
    This API is intended for PT2 quantization training use cases and will
    become the default IR of ``torch.export.export`` in the near future.

    **Soundness Guarantee**

    See the :func:`export()` docstring for more details.

    Args:
        mod: We will trace the forward method of this module.

        args: Example positional inputs.

        kwargs: Optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` to their dynamic shape
         specifications, or
         2) a tuple that specifies dynamic shape specifications for each input
         in original order.
         If you are specifying dynamism on keyword args, you will need to pass
         them in the order that is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where
         static dimension indices may be omitted, and map to None when they are
         included; or (2) a tuple / list of :func:`Dim` types or None, where
         :func:`Dim` entries mark dynamic dimensions and None marks static
         ones. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained
         specifications.

        strict: When enabled (default), the export function will trace the
         program through TorchDynamo, which ensures the soundness of the
         resulting graph. Otherwise, the exported program will not validate
         the implicit assumptions baked into the graph, and behavior may
         diverge between the original model and the exported one. This is
         useful when users need to work around bugs in the tracer, or simply
         want to enable safety in their models incrementally. Note that this
         does not change the resulting IR spec, and the model is serialized
         the same way regardless of the value passed here.
         WARNING: This option is experimental and use this at your own risk.

    Returns:
        An :class:`ExportedProgram` containing the traced callable.

    **Acceptable input/output types**

    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:

    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
    - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
      ``OrderedDict`` containing all above types.

    """
    from ._trace import _export_for_training

    # ScriptModule is itself an nn.Module subclass, so reject it explicitly —
    # the generic Module check below would otherwise let it through.
    if isinstance(mod, torch.jit.ScriptModule):
        raise ValueError(
            "Exporting a ScriptModule is not supported. "
            "Maybe try converting your ScriptModule to an ExportedProgram "
            "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
        )
    if not isinstance(mod, torch.nn.Module):
        raise ValueError(
            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
        )

    return _export_for_training(
        mod,
        args,
        kwargs,
        dynamic_shapes,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
    )
162
163
def export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """
    Trace the forward method of an :class:`torch.nn.Module` ahead-of-time
    (AOT) with example inputs and return a traced graph representing only the
    Tensor computation, which can later be executed with different inputs or
    serialized.

    The traced graph (1) contains normalized operators from the functional
    ATen operator set (as well as any user-specified custom operators),
    (2) has eliminated all Python control flow and data structures (with
    certain exceptions), and (3) records the set of shape constraints needed
    to show that this normalization and control-flow elimination is sound for
    future inputs.

    **Soundness Guarantee**

    While tracing, :func:`export()` takes note of shape-related assumptions
    made by the user program and the underlying PyTorch operator kernels.
    The output :class:`ExportedProgram` is considered valid only when these
    assumptions hold true.

    Tracing makes assumptions on the shapes (not values) of input tensors.
    Such assumptions must be validated at graph capture time for
    :func:`export` to succeed. Specifically:

    - Assumptions on static shapes of input tensors are automatically validated without additional effort.
    - Assumptions on dynamic shape of input tensors require explicit specification
      by using the :func:`Dim` API to construct dynamic dimensions and by associating
      them with example inputs through the ``dynamic_shapes`` argument.

    If any assumption can not be validated, a fatal error will be raised. When
    that happens, the error message includes suggested fixes to the
    specification that are needed to validate the assumptions. For example
    :func:`export` might suggest the following fix to the definition of a
    dynamic dimension ``dim0_x``, say appearing in the shape associated with
    input ``x``, that was previously defined as ``Dim("dim0_x")``::

        dim = Dim("dim0_x", max=5)

    This example means the generated code requires dimension 0 of input ``x``
    to be less than or equal to 5 to be valid. You can inspect the suggested
    fixes to dynamic dimension definitions and then copy them verbatim into
    your code without needing to change the ``dynamic_shapes`` argument to
    your :func:`export` call.

    Args:
        mod: We will trace the forward method of this module.

        args: Example positional inputs.

        kwargs: Optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` to their dynamic shape
         specifications, or
         2) a tuple that specifies dynamic shape specifications for each input
         in original order.
         If you are specifying dynamism on keyword args, you will need to pass
         them in the order that is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where
         static dimension indices may be omitted, and map to None when they are
         included; or (2) a tuple / list of :func:`Dim` types or None, where
         :func:`Dim` entries mark dynamic dimensions and None marks static
         ones. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained
         specifications.

        strict: When enabled (default), the export function will trace the
         program through TorchDynamo, which ensures the soundness of the
         resulting graph. Otherwise, the exported program will not validate
         the implicit assumptions baked into the graph, and behavior may
         diverge between the original model and the exported one. This is
         useful when users need to work around bugs in the tracer, or simply
         want to enable safety in their models incrementally. Note that this
         does not change the resulting IR spec, and the model is serialized
         the same way regardless of the value passed here.
         WARNING: This option is experimental and use this at your own risk.

    Returns:
        An :class:`ExportedProgram` containing the traced callable.

    **Acceptable input/output types**

    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:

    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
    - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
      ``OrderedDict`` containing all above types.

    """
    from ._trace import _export

    # ScriptModule is itself an nn.Module subclass, so reject it explicitly —
    # the generic Module check below would otherwise let it through.
    if isinstance(mod, torch.jit.ScriptModule):
        raise ValueError(
            "Exporting a ScriptModule is not supported. "
            "Maybe try converting your ScriptModule to an ExportedProgram "
            "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
        )
    if not isinstance(mod, torch.nn.Module):
        raise ValueError(
            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
        )

    return _export(
        mod,
        args,
        kwargs,
        dynamic_shapes,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
        pre_dispatch=True,
    )
279
280
281def save(
282    ep: ExportedProgram,
283    f: Union[str, os.PathLike, io.BytesIO],
284    *,
285    extra_files: Optional[Dict[str, Any]] = None,
286    opset_version: Optional[Dict[str, int]] = None,
287) -> None:
288    """
289
290    .. warning::
291        Under active development, saved files may not be usable in newer versions
292        of PyTorch.
293
294    Saves an :class:`ExportedProgram` to a file-like object. It can then be
295    loaded using the Python API :func:`torch.export.load <torch.export.load>`.
296
297    Args:
298        ep (ExportedProgram): The exported program to save.
299
300        f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
301         implement write and flush) or a string containing a file name.
302
303        extra_files (Optional[Dict[str, Any]]): Map from filename to contents
304         which will be stored as part of f.
305
306        opset_version (Optional[Dict[str, int]]): A map of opset names
307         to the version of this opset
308
309
310    Example::
311
312        import torch
313        import io
314
315        class MyModule(torch.nn.Module):
316            def forward(self, x):
317                return x + 10
318
319        ep = torch.export.export(MyModule(), (torch.randn(5),))
320
321        # Save to file
322        torch.export.save(ep, 'exported_program.pt2')
323
324        # Save to io.BytesIO buffer
325        buffer = io.BytesIO()
326        torch.export.save(ep, buffer)
327
328        # Save with extra files
329        extra_files = {'foo.txt': b'bar'.decode('utf-8')}
330        torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
331
332    """
333    if not isinstance(ep, ExportedProgram):
334        raise TypeError(
335            f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
336        )
337
338    from torch._export.serde.schema import SCHEMA_VERSION
339    from torch._export.serde.serialize import serialize, SerializedArtifact
340
341    artifact: SerializedArtifact = serialize(ep, opset_version)
342
343    if isinstance(f, (str, os.PathLike)):
344        f = os.fspath(f)
345
346    with zipfile.ZipFile(f, "w") as zipf:
347        # Save every field in the SerializedArtifact to a file.
348        assert isinstance(artifact.exported_program, bytes)
349        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
350        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
351        zipf.writestr("serialized_constants.pt", artifact.constants)
352        zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)
353
354        zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
355
356        # Add extra files if provided
357        if extra_files:
358            for extra_file_name, content in extra_files.items():
359                encoded_content = content.encode("utf-8")
360                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
361
362
363def load(
364    f: Union[str, os.PathLike, io.BytesIO],
365    *,
366    extra_files: Optional[Dict[str, Any]] = None,
367    expected_opset_version: Optional[Dict[str, int]] = None,
368) -> ExportedProgram:
369    """
370
371    .. warning::
372        Under active development, saved files may not be usable in newer versions
373        of PyTorch.
374
375    Loads an :class:`ExportedProgram` previously saved with
376    :func:`torch.export.save <torch.export.save>`.
377
378    Args:
379        ep (ExportedProgram): The exported program to save.
380
381        f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
382         implement write and flush) or a string containing a file name.
383
384        extra_files (Optional[Dict[str, Any]]): The extra filenames given in
385         this map would be loaded and their content would be stored in the
386         provided map.
387
388        expected_opset_version (Optional[Dict[str, int]]): A map of opset names
389         to expected opset versions
390
391    Returns:
392        An :class:`ExportedProgram` object
393
394    Example::
395
396        import torch
397        import io
398
399        # Load ExportedProgram from file
400        ep = torch.export.load('exported_program.pt2')
401
402        # Load ExportedProgram from io.BytesIO object
403        with open('exported_program.pt2', 'rb') as f:
404            buffer = io.BytesIO(f.read())
405        buffer.seek(0)
406        ep = torch.export.load(buffer)
407
408        # Load with extra files.
409        extra_files = {'foo.txt': ''}  # values will be replaced with data
410        ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
411        print(extra_files['foo.txt'])
412        print(ep(torch.randn(5)))
413    """
414    if isinstance(f, (str, os.PathLike)):
415        f = os.fspath(f)
416
417    extra_files = extra_files or {}
418
419    with zipfile.ZipFile(f, "r") as zipf:
420        # Check the version
421        version = zipf.read("version").decode().split(".")
422        from torch._export.serde.schema import SCHEMA_VERSION
423
424        assert len(version) == len(SCHEMA_VERSION)
425        if version[0] != str(SCHEMA_VERSION[0]):
426            raise RuntimeError(
427                f"Serialized version {version} does not match our current "
428                f"schema version {SCHEMA_VERSION}."
429            )
430
431        from torch._export.serde.serialize import deserialize, SerializedArtifact
432
433        # Load serialized_ep and serialized_state_dict from the zip file
434
435        serialized_exported_program: Optional[bytes] = None
436        serialized_state_dict: Optional[bytes] = None
437        serialized_constants: Optional[bytes] = None
438        serialized_example_inputs: Optional[bytes] = None
439
440        for file_info in zipf.infolist():
441            file_content = zipf.read(file_info.filename)
442
443            if file_info.filename == "serialized_exported_program.json":
444                serialized_exported_program = file_content
445            elif file_info.filename == "serialized_state_dict.json":
446                warnings.warn("This version of file is deprecated")
447                serialized_state_dict = file_content
448            elif file_info.filename == "serialized_constants.json":
449                warnings.warn("This version of file is deprecated")
450                serialized_constants = file_content
451            elif file_info.filename == "serialized_state_dict.pt":
452                serialized_state_dict = file_content
453            elif file_info.filename == "serialized_constants.pt":
454                serialized_constants = file_content
455            elif file_info.filename == "serialized_example_inputs.pt":
456                serialized_example_inputs = file_content
457            elif file_info.filename.startswith("extra_files"):
458                filename = file_info.filename.split("/", 1)[1]
459                extra_files[filename] = file_content.decode("utf-8")
460
461        assert serialized_exported_program is not None
462        assert serialized_state_dict is not None
463        assert serialized_constants is not None
464        assert serialized_example_inputs is not None
465        artifact: SerializedArtifact = SerializedArtifact(
466            serialized_exported_program,
467            serialized_state_dict,
468            serialized_constants,
469            serialized_example_inputs,
470        )
471
472        # Deserialize ExportedProgram
473        ep = deserialize(artifact, expected_opset_version)
474
475        return ep
476
477
478def register_dataclass(
479    cls: Type[Any],
480    *,
481    serialized_type_name: Optional[str] = None,
482) -> None:
483    """
484    Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
485
486    Args:
487        cls: the dataclass type to register
488        serialized_type_name: The serialized name for the dataclass. This is
489        required if you want to serialize the pytree TreeSpec containing this
490        dataclass.
491
492    Example::
493
494        @dataclass
495        class InputDataClass:
496            feature: torch.Tensor
497            bias: int
498
499        class OutputDataClass:
500            res: torch.Tensor
501
502        torch.export.register_dataclass(InputDataClass)
503        torch.export.register_dataclass(OutputDataClass)
504
505        def fn(o: InputDataClass) -> torch.Tensor:
506            res = res=o.feature + o.bias
507            return OutputDataClass(res=res)
508
509        ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
510        print(ep)
511
512    """
513
514    from torch._export.utils import register_dataclass_as_pytree_node
515
516    return register_dataclass_as_pytree_node(
517        cls, serialized_type_name=serialized_type_name
518    )
519