# mypy: allow-untyped-defs
from __future__ import annotations

import abc
import contextlib
import dataclasses
import difflib
import io
import logging
import sys
from typing import Any, Callable, TYPE_CHECKING

import torch
import torch.fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher


if TYPE_CHECKING:
    from torch._subclasses import fake_tensor


@dataclasses.dataclass
class PackageInfo:
    package_name: str
    version: str | None
    commit_hash: str | None

    def to_onnx_domain_string(self) -> str:
        return ".".join(
            filter(None, ("pkg", self.package_name, self.version, self.commit_hash))
        )

    @classmethod
    def from_python_class(cls, python_class_name: type | str) -> PackageInfo:
        if isinstance(python_class_name, type):
            python_class_name = python_class_name.__module__
        package_name = python_class_name.split(".")[0]
        package = __import__(package_name)
        version = getattr(package, "__version__", None)
        # TODO: Figure out how to retrieve commit hash.
        commit_hash = None
        return cls(package_name, version, commit_hash)


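# Illustrative usage (not part of the original file); the version component
# depends on the locally installed package, so the exact output will vary:
#
#   info = PackageInfo.from_python_class(torch.nn.Linear)
#   info.to_onnx_domain_string()  # e.g. "pkg.torch.2.4.0"
#
# `filter(None, ...)` in `to_onnx_domain_string` drops the `None` fields, so a
# missing version or commit hash simply shortens the domain string.

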
@dataclasses.dataclass
class GraphModuleOnnxMeta:
    package_info: PackageInfo


@contextlib.contextmanager
def _patch_difflib_sequence_matcher_init():
    """Context manager that patches `difflib.SequenceMatcher` for diffing fx readable graphs.

    Under this context, the `autojunk` argument of `difflib.SequenceMatcher` is always
    treated as `False`. This prevents `difflib.SequenceMatcher` from classifying
    stacktrace messages in the fx readable graph as junk: these messages tend to be
    long (>200 characters) and to repeat many times, which matches the junk filter
    criteria.

    `difflib.SequenceMatcher` is used under the hood by all sorts of diffing functions
    in `difflib`, including `difflib.unified_diff`, `difflib.ndiff`, and
    `difflib.context_diff`. Unfortunately, there is no way to pass the `autojunk`
    argument to these functions, and they all default it to `True`, so this context
    patch affects all of them.

    `Reference: Automatic junk heuristic <https://docs.python.org/3/library/difflib.html>`_
    """
    original_init = difflib.SequenceMatcher.__init__

    def patched_init(self, isjunk=None, a="", b="", autojunk=True):
        original_init(self, isjunk, a, b, autojunk=False)

    difflib.SequenceMatcher.__init__ = patched_init  # type: ignore[assignment]
    try:
        yield
    finally:
        difflib.SequenceMatcher.__init__ = original_init  # type: ignore[assignment]


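# Sketch of the heuristic this patch disables (strings are hypothetical, chosen
# only to trigger the length >= 200, frequency > 1% junk rule):
#
#   a, b = "x" * 300, "x" * 300
#   difflib.SequenceMatcher(None, a, b).ratio()
#   # -> 0.0 with the stock autojunk=True default: 'x' is "popular" junk.
#   with _patch_difflib_sequence_matcher_init():
#       difflib.SequenceMatcher(None, a, b).ratio()
#       # -> 1.0, since autojunk is forced to False inside the context.

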
def _unified_diff(a: str, b: str) -> str:
    """Return a string containing the unified diff of two strings.

    This function calls a patched version of `difflib.unified_diff` with `autojunk` set
    to `False` for the `difflib.SequenceMatcher` class. More details can be found in
    the `_patch_difflib_sequence_matcher_init` function.

    Args:
        a: The first string.
        b: The second string.

    Returns:
        The unified diff of the two strings. If there is no diff, return "<no diff>".

    Example::

        >>> a = '''class GraphModule(torch.nn.Module):
        ...     def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
        ...         # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        ...         view = input_ids.view(-1, 3);  input_ids = None
        ... '''
        >>> b = '''class <lambda>(torch.nn.Module):
        ...     def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]):
        ...         # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        ...         view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]);  input_ids = None
        ... '''
        >>> print(_unified_diff(a, b))
        ---
        +++
        @@ -1,4 +1,4 @@
        -class GraphModule(torch.nn.Module):
        -    def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
        +class <lambda>(torch.nn.Module):
        +    def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]):
                # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        -        view = input_ids.view(-1, 3);  input_ids = None
        +        view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]);  input_ids = None
    """

    a_list = a.splitlines(keepends=True)
    b_list = b.splitlines(keepends=True)

    with _patch_difflib_sequence_matcher_init():
        # Set `n` to `sys.maxsize` to show the entire graph when there is a diff.
        diff = "".join(difflib.unified_diff(a_list, b_list, n=sys.maxsize))

    if not diff:
        return "<no diff>"
    return diff


def _transform_diagnose_call_message_formatter(
    run: Callable,
    self: Transform,
    *args: Any,
    **kwargs: Any,
) -> str:
    return f"Running {self.__class__.__name__} pass. "


def maybe_fx_graph_tabular(graph: torch.fx.Graph) -> str | None:
    """Return the Graph nodes in tabular format. Equivalent to stdout of `graph.print_tabular()`.
    If `tabulate` is not installed, return `None`.

    Args:
        graph: The Graph to print.

    Returns:
        The Graph printed in a tabular format. None if `tabulate` is not installed.
    """
    f = io.StringIO()
    with contextlib.redirect_stdout(f):
        try:
            graph.print_tabular()
        except ImportError:
            return None
    return f.getvalue()


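# Illustrative usage (not from the original file); `print_tabular` raises
# `ImportError` when the optional `tabulate` dependency is missing, which is
# what the `except` branch above converts into `None`:
#
#   gm = torch.fx.symbolic_trace(torch.nn.Linear(2, 2))
#   table = maybe_fx_graph_tabular(gm.graph)
#   print(table if table is not None else "`tabulate` is not installed")

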
class Transform(abc.ABC):
    """Base class for FX graph transformations to be used by FX-ONNX exporter.

    Similar to `FX Interpreter <https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter>`_,
    specializations of this class execute the FX graph Node-by-Node.
    Methods in the `Transform` class can be overridden to customize the behavior of the transformation.
    This pattern is useful for many things, including writing code transformations as well as analysis passes.

    The following methods can be overridden::

        _run()
            +-- run_node()
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    One important aspect to note is that if the transformation modifies the model input and/or output signature
    (e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`
    are needed to reconcile :attr:`ONNXProgram.model_proto`.
    That is, the model signature and the model representation must match.

    As an additional feature, this class provides builtin support for recording transformations
    via diagnostics. The granularity of overriding is up to the user, and it determines the
    granularity of the recorded diagnostics. For example, if `_run()` is overridden, the
    diagnostics will only contain graph-level transformation information. If `call_function()`
    is overridden as well, the diagnostics will additionally contain node-level information
    for each `call_function()` node.

    TODO(bowbao): Add more overridable methods in call hierarchy
    TODO(bowbao): Create an example once more overridable methods are added.
    """

    diagnostic_context: diagnostics.DiagnosticContext
    """The diagnostic context for recording diagnostics."""

    module: torch.fx.GraphModule
    """The module to be transformed."""

    fake_mode: fake_tensor.FakeTensorMode | None
    """The existing fake mode detected from `self.module`."""

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
    ):
        """Initialize the transform.

        Args:
            diagnostic_context: The diagnostic context for recording diagnostics.
            module: The module to be transformed.
        """
        self.diagnostic_context = diagnostic_context
        self.module = module
        self.fake_mode = self._detect_fake_mode()

    def _detect_fake_mode(self) -> fake_tensor.FakeTensorMode | None:
        """Detect fake mode from the graph.

        Scan through all nodes in the graph and their `meta['val']` to detect fake mode.
        """
        fake_tensors = [node.meta.get("val") for node in self.module.graph.nodes]
        with unset_fake_temporarily():
            return torch._dynamo.utils.detect_fake_mode(fake_tensors)

    def _maybe_fakefy_args(
        self, fake_mode: fake_tensor.FakeTensorMode | None, *args: Any
    ) -> tuple[Any, ...]:
        if fake_mode is None:
            return args
        # NB: This should hit the cache if tensors were fakefied before.
        # E.g., when the fx graph is produced by Dynamo.
        return tuple(
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args
        )

    @abc.abstractmethod
    def _run(self, *args, **kwargs) -> torch.fx.GraphModule: ...

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_pass,
        diagnostic_message_formatter=_transform_diagnose_call_message_formatter,
    )
    def run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Run the transform on `self.module`.

        Note that this method may or may not mutate `self.module`, and the returned
        `GraphModule` could be either `self.module` or a new `GraphModule`.

        Args:
            *args: Positional arguments for `self.module` to run.
            **kwargs: Keyword arguments for `self.module` to run.
        """
        diagnostic = self.diagnostic_context.inflight_diagnostic(
            rule=diagnostics.rules.fx_pass
        )
        diagnostic.info(
            "For detailed logging of graph modifications by this pass, either set "
            "`DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable "
            "`TORCH_LOGS='onnx_diagnostics'`."
        )

        # Gather graph information before transform.
        graph_diff_log_level = logging.DEBUG
        if diagnostic.logger.isEnabledFor(graph_diff_log_level):
            # Cannot use LazyString because the graph may have been mutated at evaluation time.
            old_readable_graph = self.module.print_readable(print_output=False)
            old_tabular = maybe_fx_graph_tabular(self.module.graph)
        else:
            # Set to empty string to avoid unbound warning. This value should never be
            # used since the log level is not enabled.
            old_readable_graph = ""
            old_tabular = ""

        module = self._run(*args, **kwargs)

        # Gather graph information after transform.
        if diagnostic.logger.isEnabledFor(graph_diff_log_level):
            new_readable_graph = module.print_readable(print_output=False)
            new_tabular = maybe_fx_graph_tabular(module.graph)

            with diagnostic.log_section(graph_diff_log_level, "Graph diff:"):
                diagnostic.log(
                    graph_diff_log_level,
                    "```\n%s\n```",
                    diagnostics.LazyString(
                        _unified_diff, old_readable_graph, new_readable_graph
                    ),
                )

            with diagnostic.log_section(graph_diff_log_level, "Tabular diff:"):
                if old_tabular is None or new_tabular is None:
                    diagnostic.log(
                        graph_diff_log_level,
                        "Tabular diff is not available because `tabulate` is not installed.",
                    )
                else:
                    diagnostic.log(
                        graph_diff_log_level,
                        "```\n%s\n```",
                        diagnostics.LazyString(_unified_diff, old_tabular, new_tabular),
                    )

        return module


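# A minimal sketch of a concrete pass (hypothetical; `_NoopPass` is not part of
# this file). `_run` is the only method a subclass must implement; `run` wraps
# it with diagnostics and graph-diff logging:
#
#   class _NoopPass(Transform):
#       def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
#           # Inspect or rewrite `self.module.graph` here, then recompile.
#           self.module.recompile()
#           return self.module
#
#   # transformed = _NoopPass(diagnostic_context, graph_module).run(*example_inputs)

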
class AnalysisResult(abc.ABC):  # noqa: B024
    """Base class for results produced by an `Analysis`."""


class Analysis(abc.ABC):
    """Base class for FX graph analysis passes used by the FX-ONNX exporter."""

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
    ):
        self.diagnostic_context = diagnostic_context
        self.module = module
        self.onnxfunction_dispatcher = onnxfunction_dispatcher

    @abc.abstractmethod
    def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult: ...

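# Hypothetical sketch of a concrete analysis (not part of this file): an
# `Analysis` consumes the module plus dispatcher and reports findings through
# an `AnalysisResult` subclass instead of mutating the graph:
#
#   @dataclasses.dataclass
#   class _CallFunctionCount(AnalysisResult):
#       count: int
#
#   class _CountCallFunctions(Analysis):
#       def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult:
#           count = sum(1 for n in self.module.graph.nodes if n.op == "call_function")
#           return _CallFunctionCount(count)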