# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript-based ONNX export, i.e. `torch.onnx.export`."""

from __future__ import annotations

import contextlib
import gzip
from typing import TYPE_CHECKING

import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace


if TYPE_CHECKING:
    from collections.abc import Generator


def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
    """Returns the current C++ call stack.

    This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
    The returned C++ call stack is a concatenated string of the C++ call stack frames.
    Each frame is separated by a newline character, in the same format as
    r"frame #[0-9]+: (?P<frame_info>.*)". See `c10/util/Backtrace.cpp` for details.
    """
    frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
    frame_messages = []
    for frame in frames:
        segments = frame.split(":", 1)
        if len(segments) == 2:
            frame_messages.append(segments[1].strip())
        else:
            frame_messages.append("<unknown frame>")
    return infra.Stack(
        frames=[
            infra.StackFrame(location=infra.Location(message=message))
            for message in frame_messages
        ]
    )
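
# A minimal usage sketch of the helper above (the frame text is illustrative,
# not an actual backtrace):
#
#   stack = _cpp_call_stack(frames_to_skip=1, frames_to_log=8)
#   # Each StackFrame keeps only the text after "frame #N:", e.g.
#   # stack.frames[0].location.message == "c10::get_backtrace(...) + 0x5e (... in libc10.so)"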


class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    This class represents all export diagnostics. It subclasses infra.Diagnostic
    and adds methods for attaching extra information, such as Python and C++ call
    stacks, to the diagnostic.
    """

    python_call_stack: infra.Stack | None = None
    cpp_call_stack: infra.Stack | None = None

    def __init__(
        self,
        *args,
        frames_to_skip: int = 1,
        cpp_stack: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.python_call_stack = self.record_python_call_stack(
            frames_to_skip=frames_to_skip
        )
        if cpp_stack:
            self.cpp_call_stack = self.record_cpp_call_stack(
                frames_to_skip=frames_to_skip
            )

    def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
        """Records the current C++ call stack in the diagnostic."""
        stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
        stack.message = "C++ call stack"
        self.with_stack(stack)
        return stack
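
    # Construction sketch (`some_rule` is a placeholder for an `infra.Rule` instance):
    #
    #   diagnostic = TorchScriptOnnxExportDiagnostic(
    #       some_rule, infra.Level.WARNING, "message", cpp_stack=True
    #   )
    #   # python_call_stack is always recorded; cpp_call_stack only when cpp_stack=True.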


class ExportDiagnosticEngine:
    """PyTorch ONNX Export diagnostic engine.

    The only purpose of creating this class instead of using `DiagnosticContext`
    directly is to provide a background context for `diagnose` calls inside the
    exporter.

    By design, one `torch.onnx.export` call should initialize one diagnostic
    context. All `diagnose` calls inside the exporter should be made in the
    context of that export. However, since the diagnostic context is currently
    accessed via a global variable, there is no guarantee that the context is
    properly initialized. Therefore, we need to provide a default background
    context to fall back to; otherwise any invocation of exporter internals,
    e.g. in unit tests, would fail due to a missing diagnostic context. This can
    be removed once the pipeline for passing the context through the exporter is
    established.
    """

    contexts: list[infra.DiagnosticContext]
    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        self.contexts = []
        self._background_context = infra.DiagnosticContext(
            name="torch.onnx",
            version=torch.__version__,
        )

    @property
    def background_context(self) -> infra.DiagnosticContext:
        return self._background_context

    def create_diagnostic_context(
        self,
        name: str,
        version: str,
        options: infra.DiagnosticOptions | None = None,
    ) -> infra.DiagnosticContext:
        """Creates a new diagnostic context.

        Args:
            name: The subject name for the diagnostic context.
            version: The subject version for the diagnostic context.
            options: The options for the diagnostic context.

        Returns:
            A new diagnostic context.
        """
        if options is None:
            options = infra.DiagnosticOptions()
        context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
            name, version, options
        )
        self.contexts.append(context)
        return context
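
    # Sketch: how the exporter entry point typically creates its context
    # (this mirrors `create_export_diagnostic_context` below):
    #
    #   context = engine.create_diagnostic_context(
    #       "torch.onnx.export", torch.__version__
    #   )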

    def clear(self):
        """Clears all diagnostic contexts."""
        self.contexts.clear()
        self._background_context.diagnostics.clear()

    def to_json(self) -> str:
        """Serializes the SARIF log of all contexts to a JSON string."""
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file, optionally gzip-compressed."""
        if compress:
            with gzip.open(file_path, "wt") as f:
                f.write(self.to_json())
        else:
            with open(file_path, "w") as f:
                f.write(self.to_json())
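
    # Sketch: both branches above write the same SARIF JSON; `compress=True`
    # only wraps it in gzip (file names are illustrative):
    #
    #   engine.dump("export_report.sarif")
    #   engine.dump("export_report.sarif.gz", compress=True)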

    def sarif_log(self):
        """Returns the SARIF log for all contexts, including the background context."""
        log = sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self.contexts],
        )

        log.runs.append(self._background_context.sarif())
        return log


engine = ExportDiagnosticEngine()
_context = engine.background_context


@contextlib.contextmanager
def create_export_diagnostic_context() -> (
    Generator[infra.DiagnosticContext, None, None]
):
    """Creates a diagnostic context for export.

    This is a workaround for code robustness since the diagnostic context is accessed
    by export internals via a global variable. See `ExportDiagnosticEngine` for more
    details.
    """
    global _context
    assert (
        _context == engine.background_context
    ), "Export context is already set. Nested export is not supported."
    _context = engine.create_diagnostic_context(
        "torch.onnx.export",
        torch.__version__,
    )
    try:
        yield _context
    finally:
        _context = engine.background_context
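
# Usage sketch for the context manager above, as the exporter is expected to use it:
#
#   with create_export_diagnostic_context() as context:
#       ...  # exporter internals call `diagnose(...)`, which logs into `context`
#   # On exit, the global context falls back to `engine.background_context`.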


def diagnose(
    rule: infra.Rule,
    level: infra.Level,
    message: str | None = None,
    frames_to_skip: int = 2,
    **kwargs,
) -> TorchScriptOnnxExportDiagnostic:
    """Creates a diagnostic and records it in the global diagnostic context.

    This is a wrapper around `context.log` that uses the global diagnostic
    context.
    """
    diagnostic = TorchScriptOnnxExportDiagnostic(
        rule, level, message, frames_to_skip=frames_to_skip, **kwargs
    )
    export_context().log(diagnostic)
    return diagnostic
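
# Sketch: a typical call site inside the exporter (`some_rule` is a placeholder
# for an `infra.Rule` instance):
#
#   diagnose(some_rule, infra.Level.WARNING, "Constant folding was skipped.")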


def export_context() -> infra.DiagnosticContext:
    """Returns the currently active diagnostic context for export.

    Falls back to `engine.background_context` when no export is in progress.
    """
    global _context
    return _context
212