# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import io
import json
import logging
import os
import re
import sys
import types
import warnings
import weakref
import zipfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.fx
import torch.utils._pytree as pytree

from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import log_export_usage
from torch.export._tree_utils import reorder_kwargs
from torch.export.graph_signature import (
    ArgumentSpec,
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from .wrappers import _wrap_submodules

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """
    allow_rnn: bool = True


# We only want to print this once to avoid flooding logs in workflows where
# capture_pre_autograd_graph is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
    from torch._inductor import config

    log.warning("+============================+")
    log.warning("|     !!!   WARNING   !!!    |")
    log.warning("+============================+")
    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
    log.warning("Please switch to use torch.export.export_for_training instead.")
    if config.is_fbcode():
        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fall back to torch.export.export_for_training.")  # noqa: B950


@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
    f: torch.nn.Module,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> torch.nn.Module:
    """
    A helper function that is intended to trace a module before any pre-autograd
    decomposition is run. The produced module will be "non-functional" and
    composed of aten operators. Later this API will be deleted in favor of the more
    general torch.export API.

    Args:
        f: nn.Module to be traced

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

    Returns:
        An nn.Module containing the traced method.
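
    Example (a minimal, illustrative sketch; the module, input shapes, and the
    ``Dim`` name below are assumptions made for this example, and new code should
    prefer :func:`torch.export.export_for_training`)::

        import torch
        from torch.export import Dim
        from torch._export import capture_pre_autograd_graph

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        example_args = (torch.randn(4, 3),)
        # Dim 0 of x is dynamic; dim 1 is static and is simply omitted.
        dynamic_shapes = {"x": {0: Dim("batch")}}
        gm = capture_pre_autograd_graph(M(), example_args, dynamic_shapes=dynamic_shapes)
        out = gm(torch.randn(8, 3))  # the traced module accepts other batch sizes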

    """
    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
    from torch._export.non_strict_utils import make_constraints
    from torch._subclasses.functional_tensor import FunctionalTensor
    from torch.export._unlift import _create_stateful_graph_module
    from torch.export.dynamic_shapes import _combine_args

    capture_pre_autograd_graph_warning()

    if sys.platform == "win32":
        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")

    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

    if kwargs is None:
        kwargs = {}

    if capture_pre_autograd_graph_using_training_ir():
        @lru_cache
        def print_export_warning():
            log.warning("Using torch.export.export_for_training(..., strict=True)")
        print_export_warning()
        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
    else:
        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

        # Do not decompose dropout for exported models, because in eval mode the dropout
        # op disappears from the graph, which makes it difficult to switch to train mode.
        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
        decomp_table = {
            op: op.decompose
            for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
            if op != torch.ops.aten.dropout.default
        }
        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
            m = torch._dynamo.export(
                f,
                dynamic_shapes=dynamic_shapes,
                assume_static_by_default=True,
                tracing_mode="symbolic",
                decomposition_table=decomp_table,
                pre_dispatch=True,
                aten_graph=True,
                _log_export_usage=False,
            )(
                *args,
                **kwargs,
            )[0]

        _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

        m.meta["inline_constraints"] = {
            k: v
            for k, v in fake_mode.shape_env.var_to_range.items()
            if re.match(r"^[if]\d+$", str(k))
        }

        if isinstance(f, torch.nn.Module):
            from torch.export._trace import _restore_state_dict
            _restore_state_dict(f, m)

        flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
        combined_args = _combine_args(f, args, kwargs)
        range_constraints = make_constraints(
            fake_mode,
            m,
            combined_args,
            dynamic_shapes,
            0,
        )

        module = _create_stateful_graph_module(
            m,
            range_constraints=range_constraints,
        )

    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to do custom user behavior as follows:

            def _my_train(self, mode: bool = True):
                ...

            def _my_eval(self):
                ...

            model.train = types.MethodType(_my_train, model)
            model.eval = types.MethodType(_my_eval, model)
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]

    # Remove Proxy objects because they cannot be deepcopied or pickled.
    if hasattr(module, "_buffers"):
        torch._export.utils.remove_proxy_from_state_dict(
            module._buffers, in_place=True
        )
    return module


def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
    same_signature: bool = True,
) -> str:
    """
    Note: this function is not stable yet.

    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, generates executable C++ code from the program, and returns
    the path to the generated shared library.

    Args:
        f: the `nn.Module` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

        options: A dictionary of options to control Inductor.

        disable_constraint_solver: Whether the dim constraint solver must be disabled.

    Returns:
        Path to the generated shared library
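
    Example (an illustrative sketch only; the module and example inputs below are
    assumptions made for this example)::

        import torch
        from torch.export import Dim
        from torch._export import aot_compile

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 4)

            def forward(self, x):
                return self.linear(x)

        example_args = (torch.randn(4, 3),)
        # Optionally mark dim 0 of x as dynamic.
        dynamic_shapes = {"x": {0: Dim("batch")}}
        so_path = aot_compile(M(), example_args, dynamic_shapes=dynamic_shapes)
        # so_path points to the compiled shared library; it can be loaded
        # with torch._export.aot_load.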
    """
    from torch.export._trace import _export_to_torch_ir
    from torch._inductor.decomposition import select_decomp_table
    from torch._inductor import config

    if config.is_predispatch:
        gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
    else:
        # We want to export to Torch IR here to utilize the pre_grad passes in
        # Inductor, which run on Torch IR.
        gm = _export_to_torch_ir(
            f,
            args,
            kwargs,
            dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
            same_signature=same_signature,
            # Disable restore_fqn because we can instead rely on the
            # dynamo_flat_name_to_original_fqn mapping that comes from Dynamo.
            restore_fqn=False,
        )

    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]

    return so_path


def aot_load(so_path: str, device: str) -> Callable:
    """
    Loads a shared library generated by aot_compile and returns a callable.

    Args:
        so_path: Path to the shared library

        device: Device on which to run the compiled model, e.g. "cpu", "cuda", or "cuda:0".

    Returns:
        A callable
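
    Example (an illustrative round trip; the module, inputs, and device choice are
    assumptions made for this example)::

        import torch
        from torch._export import aot_compile, aot_load

        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        example_args = (torch.randn(2, 3),)
        so_path = aot_compile(M(), example_args)
        compiled = aot_load(so_path, device="cpu")
        out = compiled(*example_args)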
    """
    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized