#!/usr/bin/env python3

r"""Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentation.

The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.

Usage:

python -m tools.onnx.gen_diagnostics \
    torch/onnx/_internal/diagnostics/rules.yaml \
    torch/onnx/_internal/diagnostics \
    torch/csrc/onnx/diagnostics/generated \
    torch/docs/source
"""

import argparse
import os
import string
import subprocess
import textwrap
from typing import Any, List, Mapping, Sequence

import yaml

from torchgen import utils as torchgen_utils
from torchgen.yaml_utils import YamlLoader


# Header comment substituted into the generated Python and C++ rule files.
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.

Diagnostic rules for PyTorch ONNX export.
"""

# Comment substituted above the generated per-rule Python classes.
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""

# Source template for one generated rule class; filled in by
# `_format_rule_for_python_class`.
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
    \"\"\"{short_description}\"\"\"
    def format_message(  # type: ignore[override]
        self,
        {message_arguments}
    ) -> str:
        \"\"\"Returns the formatted default message of this Rule.

        Message template: {message_template}
        \"\"\"
        return self.message_default_template.format({message_arguments_assigned})

    def format(  # type: ignore[override]
        self,
        level: infra.Level,
        {message_arguments}
    ) -> Tuple[infra.Rule, infra.Level, str]:
        \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

        Message template: {message_template}
        \"\"\"
        return self, level, self.format_message({message_arguments_assigned})

"""

# Source template for one dataclass field of the generated rule collection;
# filled in by `_format_rule_for_python_field`.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
    default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
    init=False,
)
\"\"\"{short_description}\"\"\"
"""

# Source template for one C++ rule enumerator; filled in by `_format_rule_for_cpp`.
_CPP_RULE_TEMPLATE = """\
/**
 * @brief {short_description}
 */
{name},
"""

# A single rule entry as loaded from the rules YAML file.
_RuleType = Mapping[str, Any]


def _kebab_case_to_snake_case(name: str) -> str:
    """Convert a kebab-case name (``foo-bar``) to snake_case (``foo_bar``)."""
    return name.replace("-", "_")


def _kebab_case_to_pascal_case(name: str) -> str:
    """Convert a kebab-case name (``foo-bar``) to PascalCase (``FooBar``)."""
    return "".join(word.capitalize() for word in name.split("-"))


def _message_field_names(message_template: str) -> List[str]:
    """Return the distinct keyword field names used in ``message_template``.

    Names are returned in order of first appearance. Each name becomes a
    parameter of the generated ``format_message`` method, so repeated fields
    are collapsed into one and anything that cannot serve as a parameter name
    is rejected.

    Raises:
        AssertionError: If the template uses positional fields (``{}`` or
            ``{0}``) or a field name that is not a plain identifier.
    """
    field_names: List[str] = []
    for _, field_name, _, _ in string.Formatter().parse(message_template):
        if field_name is None:
            # Trailing literal text with no replacement field.
            continue
        assert isinstance(field_name, str), (
            f"Unexpected field type {type(field_name)} from {field_name}. "
            f"Field name must be string.\nFull message template: {message_template}"
        )
        # An empty name is the auto-numbered positional form `{}`.
        assert field_name and not field_name.isnumeric(), (
            f"Unexpected positional or numeric field name {field_name!r}. "
            "Only keyword name formatting is supported.\n"
            f"Full message template: {message_template}"
        )
        # Attribute/index access such as `{a.b}` or `{a[0]}` cannot be used as
        # a parameter name in the generated signature.
        assert field_name.isidentifier(), (
            f"Unsupported field name {field_name!r}. "
            "Field names must be plain identifiers.\n"
            f"Full message template: {message_template}"
        )
        if field_name not in field_names:
            field_names.append(field_name)
    return field_names


def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Render the Python class source for ``rule``.

    The class overrides ``format_message`` and ``format`` so that their
    signatures list the keyword arguments of the rule's default message
    template.

    Raises:
        AssertionError: If the message template contains unsupported fields.
    """
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]
    field_names = _message_field_names(message_template)
    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        f"{field_name}={field_name}" for field_name in field_names
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=pascal_case_name,
        short_description=short_description,
        # repr() keeps the template on one line inside the generated docstring.
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )


def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the dataclass field source that instantiates ``rule``."""
    snake_case_name = _kebab_case_to_snake_case(rule["name"])
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]

    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(
        snake_case_name=snake_case_name,
        pascal_case_name=pascal_case_name,
        # The rule mapping is embedded through its string form.
        sarif_dict=rule,
        short_description=short_description,
    )


def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the C++ enumerator source (``k<PascalCaseName>``) for ``rule``."""
    name = f"k{_kebab_case_to_pascal_case(rule['name'])}"
    short_description = rule["short_description"]["text"]
    return _CPP_RULE_TEMPLATE.format(name=name, short_description=short_description)


def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Write ``_rules.py`` into ``out_py_dir`` from ``rules.py.in`` and lint it."""
    rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules]
    rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules]

    fm = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    fm.write_with_template(
        "_rules.py",
        "rules.py.in",
        lambda: {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": "\n".join(rule_class_lines),
            "rules": textwrap.indent("\n".join(rule_field_lines), " " * 4),
        },
    )
    _lint_file(os.path.join(out_py_dir, "_rules.py"))


def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Write ``rules.h`` into ``out_cpp_dir`` from ``rules.h.in`` and lint it."""
    rule_lines = [_format_rule_for_cpp(rule) for rule in rules]
    rule_names = [f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules]

    fm = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    fm.write_with_template(
        "rules.h",
        "rules.h.in",
        lambda: {
            "generated_comment": textwrap.indent(
                _RULES_GENERATED_COMMENT,
                " * ",
                predicate=lambda x: True,  # Don't ignore empty line
            ),
            "rules": textwrap.indent("\n".join(rule_lines), " " * 2),
            "py_rule_names": textwrap.indent("\n".join(rule_names), " " * 4),
        },
    )
    _lint_file(os.path.join(out_cpp_dir, "rules.h"))


def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing."""
    # TODO: Add doc generation in a follow-up PR.
    pass


def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on ``file_path`` and wait for it to finish.

    The exit status is deliberately not checked, matching the previous
    behaviour of this script.
    """
    subprocess.run(["lintrunner", "-a", file_path], check=False)


def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Load the rules from ``rules_path`` and generate all outputs.

    Args:
        rules_path: Path to the rules YAML file.
        out_py_dir: Output directory for the generated Python file.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory for docs (not generated yet).
    """
    # NOTE(review): yaml.load runs with the project-provided YamlLoader;
    # rules_path is presumably the repository's own rules.yaml. Do not point
    # this at untrusted input without confirming the loader is safe.
    with open(rules_path, encoding="utf-8") as f:
        rules = yaml.load(f, Loader=YamlLoader)

    # Templates live in the "templates" directory next to this script.
    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

    gen_diagnostics_python(
        rules,
        out_py_dir,
        template_dir,
    )

    gen_diagnostics_cpp(
        rules,
        out_cpp_dir,
        template_dir,
    )

    gen_diagnostics_docs(rules, out_docs_dir, template_dir)


def main() -> None:
    """Parse the command-line arguments and run the generator."""
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")
    parser.add_argument("rules_path", metavar="RULES", help="path to rules.yaml")
    parser.add_argument(
        "out_py_dir",
        metavar="OUT_PY",
        help="path to output directory for Python",
    )
    parser.add_argument(
        "out_cpp_dir",
        metavar="OUT_CPP",
        help="path to output directory for C++",
    )
    parser.add_argument(
        "out_docs_dir",
        metavar="OUT_DOCS",
        help="path to output directory for docs",
    )
    args = parser.parse_args()
    gen_diagnostics(
        args.rules_path,
        args.out_py_dir,
        args.out_cpp_dir,
        args.out_docs_dir,
    )


if __name__ == "__main__":
    main()