# mypy: allow-untyped-defs
"""This file defines an additional layer of abstraction on top of the SARIF OM."""

from __future__ import annotations

import dataclasses
import enum
import logging
from typing import Mapping, Sequence

from torch.onnx._internal.diagnostics.infra import formatter, sarif


class Level(enum.IntEnum):
    """The level of a diagnostic.

    This class is used to represent the level of a diagnostic. The levels are defined
    by the SARIF specification, and are not modifiable. For alternative categories,
    please use infra.Tag instead. When selecting a level, please consider the following
    guidelines:

    - NONE: Informational result that does not indicate the presence of a problem.
    - NOTE: An opportunity for improvement was found.
    - WARNING: A potential problem was found.
    - ERROR: A serious problem was found.

    This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
    value maps to the logging levels in Python's logging module. The mapping is as
    follows:

        Level.NONE = logging.DEBUG = 10
        Level.NOTE = logging.INFO = 20
        Level.WARNING = logging.WARNING = 30
        Level.ERROR = logging.ERROR = 40
    """

    NONE = 10
    NOTE = 20
    WARNING = 30
    ERROR = 40


# Lowercase alias of the `Level` enum class.
levels = Level


class Tag(enum.Enum):
    """The tag of a diagnostic. This class can be inherited to define custom tags."""


class PatchedPropertyBag(sarif.PropertyBag):
    """Key/value pairs that provide additional information about the object.

    The definition of PropertyBag via SARIF spec is "A property bag is an object (section 3.6)
    containing an unordered set of properties with arbitrary names." However it is not
    reflected in the json file, and therefore not captured by the python representation.
    This patch adds additional **kwargs to the `__init__` method to allow recording
    arbitrary key/value pairs.
    """

    def __init__(self, tags: list[str] | None = None, **kwargs):
        super().__init__(tags=tags)
        # Store every extra keyword directly as an instance attribute.
        self.__dict__.update(kwargs)


@dataclasses.dataclass(frozen=True)
class Rule:
    """An immutable diagnostic rule.

    Holds the rule's id and name, the default message template used by
    `format_message`, and optional descriptive metadata.
    """

    id: str
    name: str
    message_default_template: str
    short_description: str | None = None
    full_description: str | None = None
    full_description_markdown: str | None = None
    help_uri: str | None = None

    @classmethod
    def from_sarif(cls, **kwargs):
        """Returns a rule from the SARIF reporting descriptor.

        Args:
            **kwargs: The descriptor's entries. ``id``, ``name`` and
                ``message_strings["default"]["text"]`` are required;
                ``short_description``, ``full_description`` and ``help_uri`` are
                optional.

        Raises:
            KeyError: If a required entry is missing.
        """
        # `dict.get(key, {})` only falls back when the key is absent. A descriptor
        # that carries the key with an explicit `None` value would otherwise fail
        # on the chained `.get`, so normalize falsy values to an empty mapping.
        short_description_entry = kwargs.get("short_description") or {}
        full_description_entry = kwargs.get("full_description") or {}

        short_description = short_description_entry.get("text")
        full_description = full_description_entry.get("text")
        full_description_markdown = full_description_entry.get("markdown")
        help_uri = kwargs.get("help_uri")

        rule = cls(
            id=kwargs["id"],
            name=kwargs["name"],
            message_default_template=kwargs["message_strings"]["default"]["text"],
            short_description=short_description,
            full_description=full_description,
            full_description_markdown=full_description_markdown,
            help_uri=help_uri,
        )
        return rule

    def sarif(self) -> sarif.ReportingDescriptor:
        """Returns a SARIF reporting descriptor of this Rule.

        The descriptions are only emitted when the corresponding text is set.
        The markdown description is attached to the full description, so it is
        emitted only when `full_description` is not None.
        """
        short_description = (
            sarif.MultiformatMessageString(text=self.short_description)
            if self.short_description is not None
            else None
        )
        full_description = (
            sarif.MultiformatMessageString(
                text=self.full_description, markdown=self.full_description_markdown
            )
            if self.full_description is not None
            else None
        )
        return sarif.ReportingDescriptor(
            id=self.id,
            name=self.name,
            short_description=short_description,
            full_description=full_description,
            help_uri=self.help_uri,
        )

    def format(self, level: Level, *args, **kwargs) -> tuple[Rule, Level, str]:
        """Returns a tuple of (rule, level, message) for a diagnostic.

        This method is used to format the message of a diagnostic. The message is
        formatted using the default template of this rule, and the arguments passed in
        as `*args` and `**kwargs`. The level is used to override the default level of
        this rule.
        """
        return (self, level, self.format_message(*args, **kwargs))

    def format_message(self, *args, **kwargs) -> str:
        """Returns the formatted default message of this Rule.

        This method should be overridden (with code generation) by subclasses to reflect
        the exact arguments needed by the message template. This is a helper method to
        create the default message for a diagnostic.
        """
        return self.message_default_template.format(*args, **kwargs)


@dataclasses.dataclass
class Location:
    """A code location with an optional message.

    Note that `function` is stored on this object but is not part of the value
    produced by `sarif()`.
    """

    uri: str | None = None
    line: int | None = None
    message: str | None = None
    start_column: int | None = None
    end_column: int | None = None
    snippet: str | None = None
    function: str | None = None

    def sarif(self) -> sarif.Location:
        """Returns the SARIF representation of this location."""
        return sarif.Location(
            physical_location=sarif.PhysicalLocation(
                artifact_location=sarif.ArtifactLocation(uri=self.uri),
                region=sarif.Region(
                    start_line=self.line,
                    start_column=self.start_column,
                    end_column=self.end_column,
                    # The snippet wrapper is always created, even when the
                    # snippet text is None.
                    snippet=sarif.ArtifactContent(text=self.snippet),
                ),
            ),
            message=sarif.Message(text=self.message)
            if self.message is not None
            else None,
        )


@dataclasses.dataclass
class StackFrame:
    """A single stack frame, identified by its location."""

    location: Location

    def sarif(self) -> sarif.StackFrame:
        """Returns the SARIF representation of this stack frame."""
        return sarif.StackFrame(location=self.location.sarif())


@dataclasses.dataclass
class Stack:
    """Records a stack trace. The frames are in order from newest to oldest stack frame."""

    frames: list[StackFrame] = dataclasses.field(default_factory=list)
    message: str | None = None

    def sarif(self) -> sarif.Stack:
        """Returns the SARIF representation of this stack."""
        return sarif.Stack(
            frames=[frame.sarif() for frame in self.frames],
            message=sarif.Message(text=self.message)
            if self.message is not None
            else None,
        )


@dataclasses.dataclass
class ThreadFlowLocation:
    """Records code location and the initial state."""

    location: Location
    state: Mapping[str, str]
    index: int
    stack: Stack | None = None

    def sarif(self) -> sarif.ThreadFlowLocation:
        """Returns the SARIF representation of this thread flow location.

        Only `location`, `state` and `stack` are forwarded; `index` is not part
        of the returned object.
        """
        return sarif.ThreadFlowLocation(
            location=self.location.sarif(),
            state=self.state,
            stack=self.stack.sarif() if self.stack is not None else None,
        )


@dataclasses.dataclass
class Graph:
    """A graph of diagnostics.

    This class stores the string representation of a model graph.
    The `nodes` and `edges` fields are unused in the current implementation.
    """

    graph: str
    name: str
    description: str | None = None

    def sarif(self) -> sarif.Graph:
        """Returns the SARIF representation of this graph.

        The graph string is carried in the SARIF `description` message, while
        this object's `name` and `description` are recorded in the property bag.
        """
        return sarif.Graph(
            description=sarif.Message(text=self.graph),
            properties=PatchedPropertyBag(name=self.name, description=self.description),
        )


@dataclasses.dataclass
class RuleCollection:
    """A collection of rules declared as dataclass fields.

    Every dataclass field whose default value is a `Rule` is a member of the
    collection. Membership is decided by the rule's `(id, name)` pair.
    """

    _rule_id_name_set: frozenset[tuple[str, str]] = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # Snapshot the (id, name) pairs of all Rule-valued field defaults once,
        # so that `in` checks are a plain set lookup.
        self._rule_id_name_set = frozenset(
            {
                (field.default.id, field.default.name)
                for field in dataclasses.fields(self)
                if isinstance(field.default, Rule)
            }
        )

    def __contains__(self, rule: Rule) -> bool:
        """Checks if the rule is in the collection."""
        return (rule.id, rule.name) in self._rule_id_name_set

    @classmethod
    def custom_collection_from_list(
        cls, new_collection_class_name: str, rules: Sequence[Rule]
    ) -> RuleCollection:
        """Creates a custom class inherited from RuleCollection with the list of rules.

        Each rule becomes a field named after the snake_case form of its name, with
        the rule itself as the default. An instance of the new class is returned.
        """
        return dataclasses.make_dataclass(
            new_collection_class_name,
            [
                (
                    formatter.kebab_case_to_snake_case(rule.name),
                    type(rule),
                    dataclasses.field(default=rule),
                )
                for rule in rules
            ],
            bases=(cls,),
        )()


class Invocation:
    """Placeholder; instantiating it raises NotImplementedError."""

    # TODO: Implement this.
    # Tracks top level call arguments and diagnostic options.
    def __init__(self) -> None:
        raise NotImplementedError


@dataclasses.dataclass
class DiagnosticOptions:
    """Options for diagnostic context.

    Attributes:
        verbosity_level: Set the amount of information logged for each diagnostics,
            equivalent to the 'level' in Python logging module.
        warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
    """

    verbosity_level: int = dataclasses.field(default=logging.INFO)
    """Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""

    warnings_as_errors: bool = dataclasses.field(default=False)
    """If True, warning diagnostics are treated as error diagnostics."""