#!/usr/bin/env python3

"""Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentation.
The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.

Usage:

python -m tools.onnx.gen_diagnostics \\
    torch/onnx/_internal/diagnostics/rules.yaml \\
    torch/onnx/_internal/diagnostics \\
    torch/csrc/onnx/diagnostics/generated \\
    torch/docs/source
"""

import argparse
import keyword
import os
import string
import subprocess
import textwrap
from typing import Any, Mapping, Sequence

import yaml

from torchgen import utils as torchgen_utils
from torchgen.yaml_utils import YamlLoader


# Header comment placed at the top of the generated Python and C++ rule files.
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.

Diagnostic rules for PyTorch ONNX export.
"""

# Comment placed above the generated per-rule Python classes.
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""

# Per-rule Python class. `{message_arguments}` is the comma-separated list of the
# message template's field names; `{message_arguments_assigned}` is the matching
# `name=name, ...` keyword list forwarded to `str.format`.
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
    \"\"\"{short_description}\"\"\"
    def format_message(  # type: ignore[override]
        self,
        {message_arguments}
    ) -> str:
        \"\"\"Returns the formatted default message of this Rule.

        Message template: {message_template}
        \"\"\"
        return self.message_default_template.format({message_arguments_assigned})

    def format(  # type: ignore[override]
        self,
        level: infra.Level,
        {message_arguments}
    ) -> Tuple[infra.Rule, infra.Level, str]:
        \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

        Message template: {message_template}
        \"\"\"
        return self, level, self.format_message({message_arguments_assigned})

"""

# Per-rule dataclass field of the generated rule collection. `{sarif_dict}` is
# the `repr` of the raw rule mapping loaded from rules.yaml.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
    default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
    init=False,
)
\"\"\"{short_description}\"\"\"
"""

# Per-rule enumerator (with a Doxygen brief) emitted into the generated C++ header.
_CPP_RULE_TEMPLATE = """\
/**
 * @brief {short_description}
 */
{name},
"""

# One rule entry as loaded from rules.yaml.
_RuleType = Mapping[str, Any]


def _kebab_case_to_snake_case(name: str) -> str:
    """Converts ``kebab-case`` to ``snake_case`` by replacing hyphens with underscores."""
    return name.replace("-", "_")


def _kebab_case_to_pascal_case(name: str) -> str:
    """Converts ``kebab-case`` to ``PascalCase``.

    Each hyphen-separated word goes through ``str.capitalize``, which also
    lowercases the rest of the word (``"onnx-IR"`` -> ``"OnnxIr"``).
    """
    return "".join(word.capitalize() for word in name.split("-"))


def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Renders the generated Python class for one rule.

    The replacement fields of the rule's default message template become the
    parameters of the generated ``format_message``/``format`` methods, in order
    of first appearance.

    Args:
        rule: One rule from rules.yaml; reads ``name``,
            ``short_description.text`` and ``message_strings.default.text``.

    Returns:
        Python source text for the rule class.

    Raises:
        ValueError: If a replacement field cannot be a keyword parameter:
            positional fields (``{}``/``{0}``), attribute/index fields
            (``{a.b}``/``{a[0]}``) or reserved words (``{class}``).
    """
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]

    # `dict.fromkeys` de-duplicates while keeping first-appearance order: a field
    # used twice in the template must appear only once in the generated
    # signature, otherwise the generated module has a duplicate-argument error.
    field_names = list(
        dict.fromkeys(
            field_name
            for _, field_name, _, _ in string.Formatter().parse(message_template)
            if field_name is not None  # None marks trailing literal text.
        )
    )
    for field_name in field_names:
        # Raise rather than `assert` so the check still runs under `python -O`.
        if not field_name.isidentifier() or keyword.iskeyword(field_name):
            raise ValueError(
                f"Unexpected field name {field_name!r} in the message template of "
                f"rule {rule['name']!r}. Only keyword name formatting is supported "
                "and field names must be valid, non-reserved Python identifiers.\n"
                f"Full message template: {message_template}"
            )

    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        f"{field_name}={field_name}" for field_name in field_names
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=pascal_case_name,
        short_description=short_description,
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )


def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Renders the dataclass field declaring ``rule`` in the generated rule collection."""
    snake_case_name = _kebab_case_to_snake_case(rule["name"])
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]

    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(
        snake_case_name=snake_case_name,
        pascal_case_name=pascal_case_name,
        sarif_dict=rule,
        short_description=short_description,
    )


def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Renders the C++ enumerator ``k<PascalCaseName>`` for ``rule`` with its brief."""
    name = f"k{_kebab_case_to_pascal_case(rule['name'])}"
    short_description = rule["short_description"]["text"]
    return _CPP_RULE_TEMPLATE.format(name=name, short_description=short_description)


def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Writes ``<out_py_dir>/_rules.py`` from the ``rules.py.in`` template, then lints it.

    Args:
        rules: Rules loaded from rules.yaml.
        out_py_dir: Output directory for the generated Python module.
        template_dir: Directory containing ``rules.py.in``.
    """
    rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules]
    rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules]

    fm = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    fm.write_with_template(
        "_rules.py",
        "rules.py.in",
        lambda: {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": "\n".join(rule_class_lines),
            # Fields are indented one level to sit inside the collection class body.
            "rules": textwrap.indent("\n".join(rule_field_lines), " " * 4),
        },
    )
    _lint_file(os.path.join(out_py_dir, "_rules.py"))


def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Writes ``<out_cpp_dir>/rules.h`` from the ``rules.h.in`` template, then lints it.

    Args:
        rules: Rules loaded from rules.yaml.
        out_cpp_dir: Output directory for the generated C++ header.
        template_dir: Directory containing ``rules.h.in``.
    """
    rule_lines = [_format_rule_for_cpp(rule) for rule in rules]
    rule_names = [f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules]

    fm = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    fm.write_with_template(
        "rules.h",
        "rules.h.in",
        lambda: {
            # Prefix every line, blank ones included, with " * " so the text
            # forms a well-formed C block-comment body.
            "generated_comment": textwrap.indent(
                _RULES_GENERATED_COMMENT,
                " * ",
                predicate=lambda x: True,  # Don't ignore empty line
            ),
            "rules": textwrap.indent("\n".join(rule_lines), " " * 2),
            "py_rule_names": textwrap.indent("\n".join(rule_names), " " * 4),
        },
    )
    _lint_file(os.path.join(out_cpp_dir, "rules.h"))


def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing."""
    # TODO: Add doc generation in a follow-up PR.
    pass


def _lint_file(file_path: str) -> None:
    """Runs ``lintrunner -a`` on ``file_path`` to auto-fix the generated output.

    The exit status is intentionally not checked: lint findings do not abort
    generation. A missing ``lintrunner`` executable still raises
    ``FileNotFoundError``.
    """
    subprocess.run(["lintrunner", "-a", file_path], check=False)


def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Loads ``rules_path`` and generates the Python, C++ and docs outputs.

    Args:
        rules_path: Path to rules.yaml.
        out_py_dir: Output directory for the generated Python module.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory for docs (generation not implemented yet).
    """
    with open(rules_path, encoding="utf-8") as f:
        rules = yaml.load(f, Loader=YamlLoader)

    # Templates live in a `templates` directory next to this script.
    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

    gen_diagnostics_python(
        rules,
        out_py_dir,
        template_dir,
    )

    gen_diagnostics_cpp(
        rules,
        out_cpp_dir,
        template_dir,
    )

    gen_diagnostics_docs(rules, out_docs_dir, template_dir)


def main() -> None:
    """Command-line entry point: parses the four positional paths and runs generation."""
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")
    parser.add_argument("rules_path", metavar="RULES", help="path to rules.yaml")
    parser.add_argument(
        "out_py_dir",
        metavar="OUT_PY",
        help="path to output directory for Python",
    )
    parser.add_argument(
        "out_cpp_dir",
        metavar="OUT_CPP",
        help="path to output directory for C++",
    )
    parser.add_argument(
        "out_docs_dir",
        metavar="OUT_DOCS",
        help="path to output directory for docs",
    )
    args = parser.parse_args()
    gen_diagnostics(
        args.rules_path,
        args.out_py_dir,
        args.out_cpp_dir,
        args.out_docs_dir,
    )


if __name__ == "__main__":
    main()