# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""

from __future__ import annotations

import contextlib
import gzip
from typing import TYPE_CHECKING

import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace


if TYPE_CHECKING:
    from collections.abc import Generator


def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
    """Returns the current C++ call stack.

    This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
    The returned C++ call stack is a concatenated string of the C++ call stack frames.
    Each frame is separated by a newline character, in the same format of
    r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.

    Args:
        frames_to_skip: Forwarded as the first argument of `get_cpp_backtrace`.
        frames_to_log: Forwarded as the second argument of `get_cpp_backtrace`.

    Returns:
        An `infra.Stack` with one `infra.StackFrame` per line of the backtrace.
        The frame message is the text after the first ":" on the line, stripped;
        lines that contain no ":" are recorded as "<unknown frame>".
    """
    backtrace = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log)
    frame_messages = []
    for frame in backtrace.split("\n"):
        # Only the first ":" separates the "frame #N" prefix from the frame info;
        # any later colons (e.g. in C++ `::` qualifiers) belong to the info itself.
        _, separator, frame_info = frame.partition(":")
        frame_messages.append(frame_info.strip() if separator else "<unknown frame>")
    return infra.Stack(
        frames=[
            infra.StackFrame(location=infra.Location(message=message))
            for message in frame_messages
        ]
    )


class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    This class is used to represent all export diagnostics. It is a subclass of
    infra.Diagnostic, and adds additional methods to add more information to the
    diagnostic.
    """

    # Python call stack captured when the diagnostic is constructed.
    python_call_stack: infra.Stack | None = None
    # C++ call stack; stays None unless the diagnostic is built with `cpp_stack=True`.
    cpp_call_stack: infra.Stack | None = None

    def __init__(
        self,
        *args,
        frames_to_skip: int = 1,
        cpp_stack: bool = False,
        **kwargs,
    ) -> None:
        """Initializes the diagnostic and records call stacks.

        Args:
            *args: Positional arguments forwarded to `infra.Diagnostic.__init__`.
            frames_to_skip: Passed to the stack recorders as `frames_to_skip`.
            cpp_stack: If True, the C++ call stack is recorded as well.
            **kwargs: Keyword arguments forwarded to `infra.Diagnostic.__init__`.
        """
        super().__init__(*args, **kwargs)
        # `record_python_call_stack` is not defined in this class; it is expected
        # to be provided by the `infra.Diagnostic` base class.
        self.python_call_stack = self.record_python_call_stack(
            frames_to_skip=frames_to_skip
        )
        if cpp_stack:
            self.cpp_call_stack = self.record_cpp_call_stack(
                frames_to_skip=frames_to_skip
            )

    def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
        """Records the current C++ call stack in the diagnostic.

        Args:
            frames_to_skip: Passed to `_cpp_call_stack` as `frames_to_skip`.

        Returns:
            The recorded stack, labelled "C++ call stack", after it has been
            passed to `self.with_stack`.
        """
        stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
        stack.message = "C++ call stack"
        self.with_stack(stack)
        return stack


class ExportDiagnosticEngine:
    """PyTorch ONNX Export diagnostic engine.

    The only purpose of creating this class instead of using `DiagnosticContext` directly
    is to provide a background context for `diagnose` calls inside exporter.

    By design, one `torch.onnx.export` call should initialize one diagnostic context.
    All `diagnose` calls inside exporter should be made in the context of that export.
    However, since diagnostic context is currently being accessed via a global variable,
    there is no guarantee that the context is properly initialized. Therefore, we need
    to provide a default background context to fallback to, otherwise any invocation of
    exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
    This can be removed once the pipeline for context to flow through the exporter is
    established.
    """

    # Contexts created through `create_diagnostic_context`, in creation order.
    contexts: list[infra.DiagnosticContext]
    # Fallback context used when no export-specific context is active.
    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        """Initializes the engine with no contexts and a fresh background context."""
        self.contexts = []
        self._background_context = infra.DiagnosticContext(
            name="torch.onnx",
            version=torch.__version__,
        )

    @property
    def background_context(self) -> infra.DiagnosticContext:
        """The fallback diagnostic context owned by this engine."""
        return self._background_context

    def create_diagnostic_context(
        self,
        name: str,
        version: str,
        options: infra.DiagnosticOptions | None = None,
    ) -> infra.DiagnosticContext:
        """Creates a new diagnostic context.

        Args:
            name: The subject name for the diagnostic context.
            version: The subject version for the diagnostic context.
            options: The options for the diagnostic context.

        Returns:
            A new diagnostic context.
        """
        if options is None:
            options = infra.DiagnosticOptions()
        context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
            name, version, options
        )
        self.contexts.append(context)
        return context

    def clear(self) -> None:
        """Clears all diagnostic contexts.

        The background context object itself is kept; only its recorded
        diagnostics are dropped.
        """
        self.contexts.clear()
        self._background_context.diagnostics.clear()

    def to_json(self) -> str:
        """Returns the SARIF log of this engine serialized as a JSON string."""
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file.

        Args:
            file_path: Path of the file to write.
            compress: If True, the JSON text is written through `gzip.open`;
                otherwise it is written as plain text.
        """
        # The encoding is pinned to UTF-8 so the bytes on disk do not depend on
        # the locale encoding of the machine that produced the log.
        if compress:
            with gzip.open(file_path, "wt", encoding="utf-8") as f:
                f.write(self.to_json())
        else:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(self.to_json())

    def sarif_log(self) -> sarif.SarifLog:
        """Builds the SARIF log for this engine.

        Returns:
            A `sarif.SarifLog` with one run per created context, in creation
            order, followed by a final run for the background context.
        """
        log = sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self.contexts],
        )

        log.runs.append(self._background_context.sarif())
        return log


engine = ExportDiagnosticEngine()
_context = engine.background_context


@contextlib.contextmanager
def create_export_diagnostic_context() -> (
    Generator[infra.DiagnosticContext, None, None]
):
    """Create a diagnostic context for export.

    This is a workaround for code robustness since diagnostic context is accessed by
    export internals via global variable. See `ExportDiagnosticEngine` for more details.

    Yields:
        The newly created context, which is the module-global context for the
        duration of the `with` block. The global is reset to the background
        context on exit, including when the block raises.
    """
    global _context
    # Identity, not equality: the question is whether the background context
    # object is the one currently installed, not whether two contexts happen to
    # compare equal.
    assert (
        _context is engine.background_context
    ), "Export context is already set. Nested export is not supported."
    _context = engine.create_diagnostic_context(
        "torch.onnx.export",
        torch.__version__,
    )
    try:
        yield _context
    finally:
        _context = engine.background_context


def diagnose(
    rule: infra.Rule,
    level: infra.Level,
    message: str | None = None,
    frames_to_skip: int = 2,
    **kwargs,
) -> TorchScriptOnnxExportDiagnostic:
    """Creates a diagnostic and record it in the global diagnostic context.

    This is a wrapper around `context.log` that uses the global diagnostic
    context.

    Args:
        rule: The rule passed to the diagnostic.
        level: The level passed to the diagnostic.
        message: Optional message passed to the diagnostic.
        frames_to_skip: Passed to `TorchScriptOnnxExportDiagnostic` as
            `frames_to_skip`.
        **kwargs: Extra keyword arguments forwarded to
            `TorchScriptOnnxExportDiagnostic`.

    Returns:
        The diagnostic that was created and logged.
    """
    diagnostic = TorchScriptOnnxExportDiagnostic(
        rule, level, message, frames_to_skip=frames_to_skip, **kwargs
    )
    export_context().log(diagnostic)
    return diagnostic


def export_context() -> infra.DiagnosticContext:
    """Returns the diagnostic context currently stored in the module global."""
    # Read-only access to the module global; no `global` declaration is needed.
    return _context