# mypy: allow-untyped-defs
from __future__ import annotations

import functools
from typing import Any, Callable, Mapping, Sequence

import torch
import torch.fx
import torch.onnx
import torch.onnx._internal.fx.passes as passes
from torch.onnx._internal import _exporter_legacy, io_adapter


# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: tuple[str, ...] = (
    "arange",
    "tensor",
    "finfo",
    "full",
    "empty",
)


class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
    """Tracer that creates an ONNX-export-friendly FX graph.

    This tracer traces models down to operators. That is, the traced graph
    mostly contains call_function nodes and has no call_module nodes.
    The call_module nodes are problematic for the use of make_fx(...) in the
    ONNX exporter.
    """

    def is_leaf_module(
        self, module: torch.nn.Module, module_qualified_name: str
    ) -> bool:
        # This returns False so that all sub-modules are considered non-leaves
        # and therefore expanded into operators in
        # torch.fx._symbolic_trace.Tracer.call_module.
        return False

    def to_bool(self, obj: torch.fx.Proxy) -> bool:
        # FIXME: This is a hack to trace through if-else Python blocks.
        # It may generate incorrect ONNX graphs if the if-else block
        # takes a different branch depending on the input.
        return False


def _wrap_for_symbolic_trace(target: Callable) -> tuple[Callable, Callable]:
    """This function wraps ``target`` for symbolic tracing.

    This function wraps ``target`` so that its wrapper produces
    torch.fx.Proxy in symbolic computation. The returned values are
    the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
    this function receives `torch.arange`, `torch.tensor`, etc. as inputs.
    """

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target


def _module_expansion_symbolic_trace(
    root: torch.nn.Module | Callable[..., Any],
    concrete_args: dict[str, Any] | None = None,
) -> torch.fx.GraphModule:
    """Trace a callable into an FX graph.

    When ``root`` is a torch.nn.Module, calls to its submodules (type: torch.nn.Module)
    are expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify
    the graph structure.
    """
    # For functions that don't support symbolic tracing, create wrappers
    # which produce symbolic results during tracing.
    patched_torch_methods = {
        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
        for target_name in _TORCH_METHODS_TO_PATCH
    }

    # Install the symbolic-tracing friendly wrappers so that `tracer.trace` below
    # can work.
    for name, (wrapper, _) in patched_torch_methods.items():
        setattr(torch, name, wrapper)

    try:
        # Set up a tracer.
        tracer = ModuleExpansionTracer()
        # Trace the model.
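        # Note: entries in `concrete_args` set to `torch.fx._symbolic_trace.PH`
        # remain symbolic placeholders during tracing; all other values are
        # specialized (treated as constants) at trace time.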
        graph = tracer.trace(root, concrete_args)
        name = (
            root.__class__.__name__
            if isinstance(root, torch.nn.Module)
            else root.__name__
        )
        return torch.fx.GraphModule(tracer.root, graph, name)
    finally:
        # Revert the patches for symbolic tracing.
        for name, (_, wrapped) in patched_torch_methods.items():
            # wrapped is the original version of `torch.name`.
            setattr(torch, name, wrapped)


# TODO: Migrate to `DynamoExporter` after fake model tracing is supported.
# Proposal at https://github.com/pytorch/pytorch/issues/95900.
class FXSymbolicTracer(_exporter_legacy.FXGraphExtractor):
    """Generates an FX GraphModule using the torch.fx.symbolic_trace API.

    Args:
        concrete_args: Inputs to be partially specialized.
            It can be used to remove control flow or data structures.
            For example::

                def f(a, b):
                    if b == True:
                        return a
                    else:
                        return a * 2

            FX typically cannot trace through this due to the presence of control
            flow. However, we can use `concrete_args` to specialize on the value of
            `b` to trace through this::

                f = fx.symbolic_trace(f, concrete_args={"b": False})
                assert f(3, False) == 6

            Note that although you can still pass in different values of `b`, they
            will be ignored.

            It can also be used to eliminate data-structure handling from
            our function. This will use pytrees to flatten your input. To avoid
            overspecializing, pass in `fx.PH` for values that shouldn't be
            specialized. For example::

                def f(x):
                    out = 0
                    for v in x.values():
                        out += v
                    return out

                f = fx.symbolic_trace(
                    f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
                )
                assert f({"a": 1, "b": 2, "c": 4}) == 7
    """

    def __init__(self, concrete_args: dict[str, Any] | None = None):
        super().__init__()
        # TODO: plumb ``concrete_args`` to the symbolic_trace call in ``generate_fx``.
        self.concrete_args = concrete_args

    def _trace_into_fx_graph_via_fx_symbolic_trace(
        self, model, model_args, model_kwargs
    ) -> torch.fx.GraphModule:
        # Bind model args and kwargs with the model signature to retrieve default
        # values of unprovided arguments. These are then used to construct
        # ``concrete_args``.
        bind_input_step = io_adapter.BindInputStep(
            torch.onnx.utils.model_signature(model)
        )
        self.input_adapter.append_step(bind_input_step)
        _, named_args = bind_input_step.apply(model_args, model_kwargs, model=model)

        # Create inputs to call symbolic trace (torch.fx.symbolic_trace).
        # Example content of concrete_args:
        #   concrete_args["x"] = torch.fx._symbolic_trace.PH
        #   concrete_args["b"] = 1
        # where "x" and "b" are argument names in "signature".
        concrete_args = {}
        for param_name, param_value in named_args.items():
            if isinstance(param_value, torch.Tensor):
                # param_value can be, e.g., a real tensor or a fake tensor.
                # param_value is treated as a substitutable tensor symbol (aka placeholder).
                concrete_args[param_name] = torch.fx._symbolic_trace.PH
            else:
                concrete_args[param_name] = param_value

        # Merge kwargs back into args since that is the format the FX graph expects.
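        # Illustrative example (not taken from a real run): after this step,
        #   model_args=(x,), model_kwargs={"b": 1}
        # would be presented to the FX graph as the purely positional form
        #   (x, 1)
        # following the argument order of the bound signature above.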
        merge_kwargs_step = io_adapter.MergeKwargsIntoArgsInputStep()
        self.input_adapter.append_step(merge_kwargs_step)
        return _module_expansion_symbolic_trace(model, concrete_args=concrete_args)

    def generate_fx(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        diagnostic_context = options.diagnostic_context
        graph_module = self._trace_into_fx_graph_via_fx_symbolic_trace(
            model, model_args, model_kwargs
        )

        # Make sure all placeholder nodes are executed before get_attr nodes.
        # Otherwise, inputs can interleave with initializers in the final
        # ModelProto.graph.input.
        # Basically, we want
        #   ModelProto.graph.input =
        #     [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
        # and we don't want
        #   ModelProto.graph.input =
        #     [input_0, weight_0, input_1, weight_1, ..., input_n, weight_n, ..., weight_m]
        graph_module = passes.MovePlaceholderToFront(
            diagnostic_context, graph_module
        ).run()
        # To save memory, move get_attr to input so that the generated model doesn't
        # have weight tensors. "replaced_attrs" is a tuple of the replaced weight tensors.
        replace_get_attr_with_placeholder_pass = passes.ReplaceGetAttrWithPlaceholder(
            diagnostic_context, graph_module
        )
        graph_module = replace_get_attr_with_placeholder_pass.run()
        replaced_attrs = replace_get_attr_with_placeholder_pass.replaced_attrs
        append_extra_input_step = io_adapter.LiftParametersAndBuffersIntoArgsInputStep(
            replaced_attrs
        )
        self.input_adapter.append_step(append_extra_input_step)
        # Move all newly created placeholder nodes to the front of the graph.
        graph_module = passes.MovePlaceholderToFront(
            diagnostic_context, graph_module
        ).run()
        # Finalize the graph editing.
        graph_module.recompile()

        updated_model_args = self.input_adapter.apply(
            *model_args, model=model, **model_kwargs
        )

        return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]

    def pre_export_passes(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        return _exporter_legacy.common_pre_export_passes(
            options, original_model, fx_module, fx_module_args
        )
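
# A minimal usage sketch (hypothetical; assumes an `options` instance of
# `_exporter_legacy.ResolvedExportOptions` has been constructed elsewhere by the
# legacy exporter, and `MyModel` is a user-defined torch.nn.Module):
#
#   extractor = FXSymbolicTracer()
#   graph_module = extractor.generate_fx(
#       options, MyModel(), model_args=(torch.randn(2, 3),), model_kwargs={}
#   )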