# mypy: allow-untyped-defs
from __future__ import annotations

import abc
import contextlib
import dataclasses
import difflib
import io
import logging
import sys
from typing import Any, Callable, TYPE_CHECKING

import torch
import torch.fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher


if TYPE_CHECKING:
    from torch._subclasses import fake_tensor


@dataclasses.dataclass
class PackageInfo:
    package_name: str
    version: str | None
    commit_hash: str | None

    def to_onnx_domain_string(self) -> str:
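        """Return the ONNX domain string for this package, e.g. "pkg.torch.2.1.0".

        Fields that are `None` are skipped, so only the known parts contribute to
        the domain string.

        Example (illustrative values; the actual version depends on the installed
        package)::

            >>> PackageInfo("torch", "2.1.0", None).to_onnx_domain_string()
            'pkg.torch.2.1.0'
        """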
        return ".".join(
            filter(None, ("pkg", self.package_name, self.version, self.commit_hash))
        )

    @classmethod
    def from_python_class(cls, python_class_name: type | str) -> PackageInfo:
        if isinstance(python_class_name, type):
            python_class_name = python_class_name.__module__
        package_name = python_class_name.split(".")[0]
        package = __import__(package_name)
        version = getattr(package, "__version__", None)
        # TODO: Figure out how to retrieve commit hash.
        commit_hash = None
        return cls(package_name, version, commit_hash)


@dataclasses.dataclass
class GraphModuleOnnxMeta:
    package_info: PackageInfo


@contextlib.contextmanager
def _patch_difflib_sequence_matcher_init():
    """Context manager that patches `difflib.SequenceMatcher` for fx readable graphs.

    Under this context, the `autojunk` argument of `difflib.SequenceMatcher` is always
    treated as `False`. This prevents `difflib.SequenceMatcher` from recognizing the
    stacktrace messages in the fx readable graph as junk, as these messages tend to be
    long (>200) and repeat multiple times, which falls under the junk filter criteria.

    `difflib.SequenceMatcher` is used under the hood by all sorts of diffing functions
    in `difflib`, including `difflib.unified_diff`, `difflib.ndiff` and `difflib.context_diff`.
    Unfortunately, there is no way to pass the `autojunk` argument to these functions, and
    they all default to `True`. This context patching affects all of them.

    `Reference: Automatic junk heuristic <https://docs.python.org/3/library/difflib.html>`_
    """
    original_init = difflib.SequenceMatcher.__init__

    def patched_init(self, isjunk=None, a="", b="", autojunk=True):
        original_init(self, isjunk, a, b, autojunk=False)

    difflib.SequenceMatcher.__init__ = patched_init  # type: ignore[assignment]
    try:
        yield
    finally:
        difflib.SequenceMatcher.__init__ = original_init  # type: ignore[assignment]


def _unified_diff(a: str, b: str) -> str:
    """Return a string containing the unified diff of two strings.

    This function calls a patched version of `difflib.unified_diff` with `autojunk` set
    to `False` for the `difflib.SequenceMatcher` class. More details can be found in
    the `_patch_difflib_sequence_matcher_init` function.

    Args:
        a: The first string.
        b: The second string.

    Returns:
        The unified diff of the two strings. If there is no diff, return "<no diff>".

    Example::

        >>> a = '''class GraphModule(torch.nn.Module):
        ...     def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
        ...         # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        ...         view = input_ids.view(-1, 3); input_ids = None
        ... '''
        >>> b = '''class <lambda>(torch.nn.Module):
        ...     def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]):
        ...         # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        ...         view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]); input_ids = None
        ... '''
        >>> print(_unified_diff(a, b))
        ---
        +++
        @@ -1,4 +1,4 @@
        -class GraphModule(torch.nn.Module):
        -    def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
        +class <lambda>(torch.nn.Module):
        +    def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]):
                 # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        -        view = input_ids.view(-1, 3); input_ids = None
        +        view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]); input_ids = None
    """

    a_list = a.splitlines(keepends=True)
    b_list = b.splitlines(keepends=True)

    with _patch_difflib_sequence_matcher_init():
        # Set `n` to `sys.maxsize` to show entire graph when there is a diff.
        diff = "".join(difflib.unified_diff(a_list, b_list, n=sys.maxsize))

    if not diff:
        return "<no diff>"
    return diff


def _transform_diagnose_call_message_formatter(
    run: Callable,
    self: Transform,
    *args: Any,
    **kwargs: Any,
) -> str:
    return f"Running {self.__class__.__name__} pass. "


def maybe_fx_graph_tabular(graph: torch.fx.Graph) -> str | None:
    """Return the Graph nodes in tabular format. Equivalent to stdout of `graph.print_tabular()`.
    If `tabulate` is not installed, return `None`.

    Args:
        graph: The Graph to print.

    Returns:
        The Graph printed in a tabular format. None if `tabulate` is not installed.
    """
    f = io.StringIO()
    with contextlib.redirect_stdout(f):
        try:
            graph.print_tabular()
        except ImportError:
            return None
    return f.getvalue()


class Transform(abc.ABC):
    """Base class for FX graph transformations used by the FX-ONNX exporter.

    Similar to `FX Interpreter <https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter>`_,
    specializations of this class execute the FX graph Node-by-Node.
    Methods in the `Transform` class can be overridden to customize the behavior of the transformation.
    This pattern can be useful for many things, including writing code transformations as well as analysis passes.

    The following methods can be overridden::

        _run()
            +-- run_node()
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    One important aspect to note is that if the transformation modifies the model input and/or output signature
    (e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`
    are needed to reconcile :attr:`ONNXProgram.model_proto`.
    That is, the model signature and the model representation must match.

    As an additional feature, this class provides built-in support for transformation recording using diagnostics.
    The granularity of overriding is up to the user, and it determines the granularity of
    the diagnostics information. For example, if `_run()` is overridden, the
    diagnostics information will only contain graph-level transformation details. In contrast,
    if `call_function()` is overridden, the diagnostics information will additionally
    contain the node-level information of `call_function()`.

    TODO(bowbao): Add more overridable methods in call hierarchy
    TODO(bowbao): Create an example once more overridable methods are added.
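
    Example (a minimal sketch of a pass that only overrides `_run`; the class name and
    the pass body below are purely illustrative)::

        class NoopPass(Transform):
            def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
                # A real pass would rewrite `self.module.graph` here and call
                # `self.module.recompile()` before returning the module.
                return self.module

        transformed_module = NoopPass(diagnostic_context, module).run()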
    """

    diagnostic_context: diagnostics.DiagnosticContext
    """The diagnostic context for recording diagnostics."""

    module: torch.fx.GraphModule
    """The module to be transformed."""

    fake_mode: fake_tensor.FakeTensorMode | None
    """The existing fake mode detected from `self.module`."""

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
    ):
        """Initialize the transform.

        Args:
            diagnostic_context: The diagnostic context for recording diagnostics.
            module: The module to be transformed.
        """
        self.diagnostic_context = diagnostic_context
        self.module = module
        self.fake_mode = self._detect_fake_mode()

    def _detect_fake_mode(self) -> fake_tensor.FakeTensorMode | None:
        """Detect fake mode from the graph.

        Scan through all nodes in the graph and their meta['val'] to detect fake mode.
        """
        fake_tensors = [node.meta.get("val") for node in self.module.graph.nodes]
        with unset_fake_temporarily():
            return torch._dynamo.utils.detect_fake_mode(fake_tensors)

    def _maybe_fakefy_args(
        self, fake_mode: fake_tensor.FakeTensorMode | None, *args: Any
    ) -> tuple[Any, ...]:
        if fake_mode is None:
            return args
        # NB: This should hit the cache if tensors were fakefied before.
        # E.g., when the fx graph is produced by Dynamo.
        return tuple(
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args
        )

    @abc.abstractmethod
    def _run(self, *args, **kwargs) -> torch.fx.GraphModule: ...

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_pass,
        diagnostic_message_formatter=_transform_diagnose_call_message_formatter,
    )
    def run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Run the transform on `self.module`.

        Note that this method may or may not mutate `self.module`, and the returned
        `GraphModule` could be either `self.module` or a new `GraphModule`.

        Args:
            *args: Positional arguments for `self.module` to run.
            **kwargs: Keyword arguments for `self.module` to run.
        """
        diagnostic = self.diagnostic_context.inflight_diagnostic(
            rule=diagnostics.rules.fx_pass
        )
        diagnostic.info(
            "For detailed logging of graph modifications by this pass, either set "
            "`DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable "
            "`TORCH_LOGS='onnx_diagnostics'`."
        )

        # Gather graph information before transform.
        graph_diff_log_level = logging.DEBUG
        if diagnostic.logger.isEnabledFor(graph_diff_log_level):
            # Cannot use LazyString because the graph may have been mutated at evaluation time.
            old_readable_graph = self.module.print_readable(print_output=False)
            old_tabular = maybe_fx_graph_tabular(self.module.graph)
        else:
            # Set to empty string to avoid unbound warning. This value should never be
            # used since the log level is not enabled.
            old_readable_graph = ""
            old_tabular = ""

        module = self._run(*args, **kwargs)

        # Gather graph information after transform.
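        # Like the pre-transform capture above, the post-transform snapshot is only taken
        # when the graph-diff log level (DEBUG) is enabled. The readable graphs (and the
        # tabular views, when `tabulate` is available) are then diffed lazily via
        # `LazyString` in the log sections below.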
        if diagnostic.logger.isEnabledFor(graph_diff_log_level):
            new_readable_graph = module.print_readable(print_output=False)
            new_tabular = maybe_fx_graph_tabular(module.graph)

            with diagnostic.log_section(graph_diff_log_level, "Graph diff:"):
                diagnostic.log(
                    graph_diff_log_level,
                    "```\n%s\n```",
                    diagnostics.LazyString(
                        _unified_diff, old_readable_graph, new_readable_graph
                    ),
                )

            with diagnostic.log_section(graph_diff_log_level, "Tabular diff:"):
                if old_tabular is None or new_tabular is None:
                    diagnostic.log(
                        graph_diff_log_level,
                        "Tabular diff is not available because `tabulate` is not installed.",
                    )
                else:
                    diagnostic.log(
                        graph_diff_log_level,
                        "```\n%s\n```",
                        diagnostics.LazyString(_unified_diff, old_tabular, new_tabular),
                    )

        return module


class AnalysisResult(abc.ABC):  # noqa: B024
    ...


class Analysis(abc.ABC):
    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
    ):
        self.diagnostic_context = diagnostic_context
        self.module = module
        self.onnxfunction_dispatcher = onnxfunction_dispatcher

    @abc.abstractmethod
    def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult: ...