xref: /aosp_15_r20/external/pytorch/tools/onnx/gen_diagnostics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3""" Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentations.
4The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.
5
6Usage:
7
8python -m tools.onnx.gen_diagnostics \
9    torch/onnx/_internal/diagnostics/rules.yaml \
10    torch/onnx/_internal/diagnostics \
11    torch/csrc/onnx/diagnostics/generated \
12    torch/docs/source
13"""
14
15import argparse
16import os
17import string
18import subprocess
19import textwrap
20from typing import Any, Mapping, Sequence
21
22import yaml
23
24from torchgen import utils as torchgen_utils
25from torchgen.yaml_utils import YamlLoader
26
27
# Header text for the generated files. It is passed to the templates as
# "generated_comment" for both the Python output (_rules.py) and the C++
# output (rules.h, where every line is additionally prefixed with " * ").
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.

Diagnostic rules for PyTorch ONNX export.
"""

# Explanatory text passed to the rules.py.in template as
# "generated_rule_class_comment".
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""

# Source template for one generated rule class; filled in by
# _format_rule_for_python_class via str.format. Placeholders:
#   pascal_case_name           - rule name in PascalCase (class is "_<name>")
#   short_description          - used as the class docstring
#   message_arguments          - comma-separated parameter names
#   message_template           - repr() of the rule's default message template
#   message_arguments_assigned - comma-separated "name=name" keyword arguments
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
    \"\"\"{short_description}\"\"\"
    def format_message(  # type: ignore[override]
        self,
        {message_arguments}
    ) -> str:
        \"\"\"Returns the formatted default message of this Rule.

        Message template: {message_template}
        \"\"\"
        return self.message_default_template.format({message_arguments_assigned})

    def format(  # type: ignore[override]
        self,
        level: infra.Level,
        {message_arguments}
    ) -> Tuple[infra.Rule, infra.Level, str]:
        \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

        Message template: {message_template}
        \"\"\"
        return self, level, self.format_message({message_arguments_assigned})

"""

# Source template for one dataclass field of the generated rule collection;
# filled in by _format_rule_for_python_field. "sarif_dict" receives the rule
# mapping itself, so its str() form is embedded in the generated source.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
    default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
    init=False,
)
\"\"\"{short_description}\"\"\"
"""

# Source template for one rule entry in the generated C++ header; filled in by
# _format_rule_for_cpp with a "k"-prefixed PascalCase name.
_CPP_RULE_TEMPLATE = """\
/**
 * @brief {short_description}
 */
{name},
"""

# Type of a single rule entry as loaded from the rules YAML file. The code in
# this file reads the "name", "short_description" and "message_strings" keys.
_RuleType = Mapping[str, Any]
84
85
86def _kebab_case_to_snake_case(name: str) -> str:
87    return name.replace("-", "_")
88
89
90def _kebab_case_to_pascal_case(name: str) -> str:
91    return "".join(word.capitalize() for word in name.split("-"))
92
93
def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Render the Python source of the generated rule class for one rule.

    The rule's default message template is parsed with ``string.Formatter`` to
    discover its replacement fields. Those field names become the parameters
    of the generated ``format_message`` and ``format`` methods.

    Args:
        rule: One rule entry. Reads ``rule["name"]`` (kebab-case),
            ``rule["short_description"]["text"]`` and
            ``rule["message_strings"]["default"]["text"]``.

    Returns:
        ``_PY_RULE_CLASS_TEMPLATE`` filled in for this rule.

    Raises:
        AssertionError: If a replacement field name is not a string, or is a
            positional index such as ``{0}``.
    """
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for chunks that carry no replacement field.
    field_names = [
        field_name
        for _, field_name, _, _ in string.Formatter().parse(message_template)
        if field_name is not None
    ]
    for field_name in field_names:
        # The messages are parenthesized so that every string piece belongs to
        # the assert statement. Without the parentheses the trailing literal is
        # a separate no-op statement and the message is silently truncated.
        assert isinstance(field_name, str), (
            f"Unexpected field type {type(field_name)} from {field_name}. "
            f"Field name must be string.\nFull message template: {message_template}"
        )
        assert not field_name.isnumeric(), (
            f"Unexpected numeric field name {field_name}. "
            "Only keyword name formatting is supported.\n"
            f"Full message template: {message_template}"
        )
    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        [f"{field_name}={field_name}" for field_name in field_names]
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=pascal_case_name,
        short_description=short_description,
        # repr() keeps the template on one line inside the generated docstring.
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )
123
124
def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the dataclass field source for one rule.

    Args:
        rule: One rule entry. Reads ``rule["name"]`` (kebab-case) and
            ``rule["short_description"]["text"]``; the mapping itself is
            embedded in the output as the ``from_sarif`` keyword arguments.

    Returns:
        ``_PY_RULE_COLLECTION_FIELD_TEMPLATE`` filled in for this rule.
    """
    rule_name = rule["name"]
    substitutions = {
        "snake_case_name": _kebab_case_to_snake_case(rule_name),
        "pascal_case_name": _kebab_case_to_pascal_case(rule_name),
        "sarif_dict": rule,
        "short_description": rule["short_description"]["text"],
    }
    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(**substitutions)
136
137
def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the C++ header entry for one rule.

    The entry name is the rule name converted to PascalCase with a ``k``
    prefix; the short description becomes the ``@brief`` comment.

    Args:
        rule: One rule entry. Reads ``rule["name"]`` and
            ``rule["short_description"]["text"]``.

    Returns:
        ``_CPP_RULE_TEMPLATE`` filled in for this rule.
    """
    cpp_name = "k" + _kebab_case_to_pascal_case(rule["name"])
    return _CPP_RULE_TEMPLATE.format(
        name=cpp_name,
        short_description=rule["short_description"]["text"],
    )
142
143
def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Write ``_rules.py`` into *out_py_dir* from the ``rules.py.in`` template.

    Args:
        rules: Rule entries to render as classes and collection fields.
        out_py_dir: Directory that receives the generated ``_rules.py``.
        template_dir: Directory handed to the file manager as its template
            directory.
    """
    # Render every rule up front so a malformed rule fails before any file
    # manager is created.
    class_sources = list(map(_format_rule_for_python_class, rules))
    field_sources = list(map(_format_rule_for_python_field, rules))

    def _template_env() -> Mapping[str, str]:
        # Substitutions consumed by rules.py.in; the fields are indented by
        # four spaces before being handed to the template.
        return {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": "\n".join(class_sources),
            "rules": textwrap.indent("\n".join(field_sources), " " * 4),
        }

    file_manager = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template("_rules.py", "rules.py.in", _template_env)

    generated_path = os.path.join(out_py_dir, "_rules.py")
    _lint_file(generated_path)
164
165
def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Write ``rules.h`` into *out_cpp_dir* from the ``rules.h.in`` template.

    Args:
        rules: Rule entries to render as C++ entries and quoted rule names.
        out_cpp_dir: Directory that receives the generated ``rules.h``.
        template_dir: Directory handed to the file manager as its template
            directory.
    """
    cpp_entries = list(map(_format_rule_for_cpp, rules))
    # Each rule name in snake_case, as a double-quoted string followed by ",".
    quoted_names = []
    for rule in rules:
        snake_name = _kebab_case_to_snake_case(rule["name"])
        quoted_names.append(f'"{snake_name}",')

    def _template_env() -> Mapping[str, str]:
        # Substitutions consumed by rules.h.in.
        header_comment = textwrap.indent(
            _RULES_GENERATED_COMMENT,
            " * ",
            # Prefix every line, including empty ones, which textwrap.indent
            # would otherwise leave untouched.
            predicate=lambda _line: True,
        )
        return {
            "generated_comment": header_comment,
            "rules": textwrap.indent("\n".join(cpp_entries), " " * 2),
            "py_rule_names": textwrap.indent("\n".join(quoted_names), " " * 4),
        }

    file_manager = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template("rules.h", "rules.h.in", _template_env)

    generated_path = os.path.join(out_cpp_dir, "rules.h")
    _lint_file(generated_path)
189
190
def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing.

    All three arguments are accepted but unused.
    """
    # TODO: Add doc generation in a follow-up PR.
    return None
196
197
def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on *file_path* and wait for it to finish.

    The exit status of the linter is not inspected.
    """
    lint_command = ["lintrunner", "-a", file_path]
    subprocess.Popen(lint_command).wait()
201
202
def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Load the rules YAML file and generate the Python, C++ and docs outputs.

    Templates are read from the ``templates`` directory that sits next to this
    file.

    Args:
        rules_path: Path to the YAML file that defines the rules.
        out_py_dir: Output directory for the generated Python file.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory for generated docs (generation is
            currently a no-op).
    """
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so the rules file is parsed identically on every platform.
    # NOTE(review): yaml.load is only as safe as the Loader it is given;
    # YamlLoader comes from torchgen - confirm it is a safe loader.
    with open(rules_path, encoding="utf-8") as f:
        rules = yaml.load(f, Loader=YamlLoader)

    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

    gen_diagnostics_python(
        rules,
        out_py_dir,
        template_dir,
    )

    gen_diagnostics_cpp(
        rules,
        out_cpp_dir,
        template_dir,
    )

    gen_diagnostics_docs(rules, out_docs_dir, template_dir)
227
228
def main() -> None:
    """Parse the command line and run the diagnostics generator.

    Four positional arguments are expected, in order: the rules YAML path and
    the output directories for Python, C++ and docs.
    """
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")
    # (dest, metavar, help) for each positional argument, in command-line order.
    positional_arguments = (
        ("rules_path", "RULES", "path to rules.yaml"),
        ("out_py_dir", "OUT_PY", "path to output directory for Python"),
        ("out_cpp_dir", "OUT_CPP", "path to output directory for C++"),
        ("out_docs_dir", "OUT_DOCS", "path to output directory for docs"),
    )
    for dest, metavar, help_text in positional_arguments:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    parsed = parser.parse_args()
    gen_diagnostics(
        rules_path=parsed.rules_path,
        out_py_dir=parsed.out_py_dir,
        out_cpp_dir=parsed.out_cpp_dir,
        out_docs_dir=parsed.out_docs_dir,
    )
254
255
# Run the generator only when this file is executed as a script (for example
# via `python -m tools.onnx.gen_diagnostics`), not when it is imported.
if __name__ == "__main__":
    main()
258