import builtins
import copy
import dataclasses
import inspect
import io
import os
import sys
import typing
import warnings
import zipfile
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils._pytree import (
    FlattenFunc,
    FromDumpableContextFn,
    ToDumpableContextFn,
    UnflattenFunc,
)


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features.
    # Do not import unconditionally, as they import sympy and importing sympy is very slow.
    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint


__all__ = [
    "Constraint",
    "Dim",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "dims",
    "export",
    "export_for_training",
    "load",
    "register_dataclass",
    "save",
    "unflatten",
    "FlatArgsAdapter",
    "UnflattenedModule",
]


from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule


# Signature of a graph-module pass usable with this package's pass infrastructure.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


def export_for_training(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """
    :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
    only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
    which can subsequently be executed with different inputs or serialized. The
    traced graph (1) produces normalized operators in the ATen operator set
    (as well as any user-specified custom operators), (2) has eliminated all Python control
    flow and data structures (with certain exceptions), and (3) records the set of
    shape constraints needed to show that this normalization and control-flow elimination
    is sound for future inputs. This API is intended for PT2 quantization training use cases
    and will become the default IR of torch.export.export in the near future.

    **Soundness Guarantee**

    See :func:`export()` docstring for more details.

    Args:
        mod: We will trace the forward method of this module.

        args: Example positional inputs.

        kwargs: Optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.

        strict: When enabled (default), the export function will trace the program through
         TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the
         exported program will not validate the implicit assumptions baked into the graph and
         may cause behavior divergence between the original model and the exported one. This is
         useful when users need to workaround bugs in the tracer, or simply want to incrementally
         enable safety in their models. Note that this does not affect the resulting IR spec
         to be different and the model will be serialized in the same way regardless of what value
         is passed here.
         WARNING: This option is experimental and use this at your own risk.

    Returns:
        An :class:`ExportedProgram` containing the traced callable.

    **Acceptable input/output types**

    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:

    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
    - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
      ``OrderedDict`` containing all above types.

    """
    from ._trace import _export_for_training

    # Validate the input module eagerly so users get a clear error before tracing.
    if not isinstance(mod, torch.nn.Module):
        raise ValueError(
            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
        )
    if isinstance(mod, torch.jit.ScriptModule):
        raise ValueError(
            "Exporting a ScriptModule is not supported. "
            "Maybe try converting your ScriptModule to an ExportedProgram "
            "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
        )
    return _export_for_training(
        mod,
        args,
        kwargs,
        dynamic_shapes,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
    )


def export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """
    :func:`export` takes an arbitrary Python callable (an nn.Module, a function or
    a method) along with example inputs, and produces a traced graph representing
    only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
    which can subsequently be executed with different inputs or serialized. The
    traced graph (1) produces normalized operators in the functional ATen operator set
    (as well as any user-specified custom operators), (2) has eliminated all Python control
    flow and data structures (with certain exceptions), and (3) records the set of
    shape constraints needed to show that this normalization and control-flow elimination
    is sound for future inputs.

    **Soundness Guarantee**

    While tracing, :func:`export()` takes note of shape-related assumptions
    made by the user program and the underlying PyTorch operator kernels.
    The output :class:`ExportedProgram` is considered valid only when these
    assumptions hold true.

    Tracing makes assumptions on the shapes (not values) of input tensors.
    Such assumptions must be validated at graph capture time for :func:`export`
    to succeed. Specifically:

    - Assumptions on static shapes of input tensors are automatically validated without additional effort.
    - Assumptions on dynamic shape of input tensors require explicit specification
      by using the :func:`Dim` API to construct dynamic dimensions and by associating
      them with example inputs through the ``dynamic_shapes`` argument.

    If any assumption can not be validated, a fatal error will be raised. When that happens,
    the error message will include suggested fixes to the specification that are needed
    to validate the assumptions. For example :func:`export` might suggest the
    following fix to the definition of a dynamic dimension ``dim0_x``, say appearing in the
    shape associated with input ``x``, that was previously defined as ``Dim("dim0_x")``::

        dim = Dim("dim0_x", max=5)

    This example means the generated code requires dimension 0 of input ``x`` to be less
    than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension
    definitions and then copy them verbatim into your code without needing to change the
    ``dynamic_shapes`` argument to your :func:`export` call.

    Args:
        mod: We will trace the forward method of this module.

        args: Example positional inputs.

        kwargs: Optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.

        strict: When enabled (default), the export function will trace the program through
         TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the
         exported program will not validate the implicit assumptions baked into the graph and
         may cause behavior divergence between the original model and the exported one. This is
         useful when users need to workaround bugs in the tracer, or simply want to incrementally
         enable safety in their models. Note that this does not affect the resulting IR spec
         to be different and the model will be serialized in the same way regardless of what value
         is passed here.
         WARNING: This option is experimental and use this at your own risk.

    Returns:
        An :class:`ExportedProgram` containing the traced callable.

    **Acceptable input/output types**

    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:

    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
    - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
      ``OrderedDict`` containing all above types.

    """
    from ._trace import _export

    # Validate the input module eagerly so users get a clear error before tracing.
    if not isinstance(mod, torch.nn.Module):
        raise ValueError(
            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
        )
    if isinstance(mod, torch.jit.ScriptModule):
        raise ValueError(
            "Exporting a ScriptModule is not supported. "
            "Maybe try converting your ScriptModule to an ExportedProgram "
            "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
        )
    return _export(
        mod,
        args,
        kwargs,
        dynamic_shapes,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
        pre_dispatch=True,
    )


def save(
    ep: ExportedProgram,
    f: Union[str, os.PathLike, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    """

    .. warning::
        Under active development, saved files may not be usable in newer versions
        of PyTorch.

    Saves an :class:`ExportedProgram` to a file-like object. It can then be
    loaded using the Python API :func:`torch.export.load <torch.export.load>`.

    Args:
        ep (ExportedProgram): The exported program to save.

        f (Union[str, os.PathLike, io.BytesIO]): A file-like object (has to
         implement write and flush) or a string containing a file name.

        extra_files (Optional[Dict[str, Any]]): Map from filename to contents
         which will be stored as part of f.

        opset_version (Optional[Dict[str, int]]): A map of opset names
         to the version of this opset


    Example::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        ep = torch.export.export(MyModule(), (torch.randn(5),))

        # Save to file
        torch.export.save(ep, 'exported_program.pt2')

        # Save to io.BytesIO buffer
        buffer = io.BytesIO()
        torch.export.save(ep, buffer)

        # Save with extra files
        extra_files = {'foo.txt': b'bar'.decode('utf-8')}
        torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

    """
    if not isinstance(ep, ExportedProgram):
        raise TypeError(
            f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
        )

    from torch._export.serde.schema import SCHEMA_VERSION
    from torch._export.serde.serialize import serialize, SerializedArtifact

    artifact: SerializedArtifact = serialize(ep, opset_version)

    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    with zipfile.ZipFile(f, "w") as zipf:
        # Save every field in the SerializedArtifact to a file.
        assert isinstance(artifact.exported_program, bytes)
        zipf.writestr("serialized_exported_program.json", artifact.exported_program)
        zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
        zipf.writestr("serialized_constants.pt", artifact.constants)
        zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)

        # Record the schema version so load() can reject incompatible files.
        zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))

        # Add extra files if provided
        if extra_files:
            for extra_file_name, content in extra_files.items():
                encoded_content = content.encode("utf-8")
                zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)


def load(
    f: Union[str, os.PathLike, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
    """

    .. warning::
        Under active development, saved files may not be usable in newer versions
        of PyTorch.

    Loads an :class:`ExportedProgram` previously saved with
    :func:`torch.export.save <torch.export.save>`.

    Args:
        f (Union[str, os.PathLike, io.BytesIO]): A file-like object (has to
         implement read and seek) or a string containing a file name.

        extra_files (Optional[Dict[str, Any]]): The extra filenames given in
         this map would be loaded and their content would be stored in the
         provided map.

        expected_opset_version (Optional[Dict[str, int]]): A map of opset names
         to expected opset versions

    Returns:
        An :class:`ExportedProgram` object

    Example::

        import torch
        import io

        # Load ExportedProgram from file
        ep = torch.export.load('exported_program.pt2')

        # Load ExportedProgram from io.BytesIO object
        with open('exported_program.pt2', 'rb') as f:
            buffer = io.BytesIO(f.read())
        buffer.seek(0)
        ep = torch.export.load(buffer)

        # Load with extra files.
        extra_files = {'foo.txt': ''}  # values will be replaced with data
        ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
        print(extra_files['foo.txt'])
        print(ep(torch.randn(5)))
    """
    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    extra_files = extra_files or {}

    with zipfile.ZipFile(f, "r") as zipf:
        # Check the version: only the major version must match our schema.
        version = zipf.read("version").decode().split(".")
        from torch._export.serde.schema import SCHEMA_VERSION

        assert len(version) == len(SCHEMA_VERSION)
        if version[0] != str(SCHEMA_VERSION[0]):
            raise RuntimeError(
                f"Serialized version {version} does not match our current "
                f"schema version {SCHEMA_VERSION}."
            )

        from torch._export.serde.serialize import deserialize, SerializedArtifact

        # Load serialized_ep and serialized_state_dict from the zip file

        serialized_exported_program: Optional[bytes] = None
        serialized_state_dict: Optional[bytes] = None
        serialized_constants: Optional[bytes] = None
        serialized_example_inputs: Optional[bytes] = None

        for file_info in zipf.infolist():
            file_content = zipf.read(file_info.filename)

            if file_info.filename == "serialized_exported_program.json":
                serialized_exported_program = file_content
            elif file_info.filename == "serialized_state_dict.json":
                # Legacy member name from an older serialization layout.
                warnings.warn("This version of file is deprecated")
                serialized_state_dict = file_content
            elif file_info.filename == "serialized_constants.json":
                # Legacy member name from an older serialization layout.
                warnings.warn("This version of file is deprecated")
                serialized_constants = file_content
            elif file_info.filename == "serialized_state_dict.pt":
                serialized_state_dict = file_content
            elif file_info.filename == "serialized_constants.pt":
                serialized_constants = file_content
            elif file_info.filename == "serialized_example_inputs.pt":
                serialized_example_inputs = file_content
            elif file_info.filename.startswith("extra_files"):
                filename = file_info.filename.split("/", 1)[1]
                extra_files[filename] = file_content.decode("utf-8")

        assert serialized_exported_program is not None
        assert serialized_state_dict is not None
        assert serialized_constants is not None
        assert serialized_example_inputs is not None
        artifact: SerializedArtifact = SerializedArtifact(
            serialized_exported_program,
            serialized_state_dict,
            serialized_constants,
            serialized_example_inputs,
        )

        # Deserialize ExportedProgram
        ep = deserialize(artifact, expected_opset_version)

        return ep


def register_dataclass(
    cls: Type[Any],
    *,
    serialized_type_name: Optional[str] = None,
) -> None:
    """
    Registers a dataclass as a valid input/output type for :func:`torch.export.export`.

    Args:
        cls: the dataclass type to register
        serialized_type_name: The serialized name for the dataclass. This is
        required if you want to serialize the pytree TreeSpec containing this
        dataclass.

    Example::

        @dataclass
        class InputDataClass:
            feature: torch.Tensor
            bias: int

        @dataclass
        class OutputDataClass:
            res: torch.Tensor

        torch.export.register_dataclass(InputDataClass)
        torch.export.register_dataclass(OutputDataClass)

        def fn(o: InputDataClass) -> torch.Tensor:
            res = o.feature + o.bias
            return OutputDataClass(res=res)

        ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
        print(ep)

    """

    from torch._export.utils import register_dataclass_as_pytree_node

    return register_dataclass_as_pytree_node(
        cls, serialized_type_name=serialized_type_name
    )