xref: /aosp_15_r20/external/pytorch/tools/onnx/gen_diagnostics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker
"""Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentation.
The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.

Usage:

python -m tools.onnx.gen_diagnostics \
    torch/onnx/_internal/diagnostics/rules.yaml \
    torch/onnx/_internal/diagnostics \
    torch/csrc/onnx/diagnostics/generated \
    torch/docs/source
"""
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerimport argparse
16*da0073e9SAndroid Build Coastguard Workerimport os
17*da0073e9SAndroid Build Coastguard Workerimport string
18*da0073e9SAndroid Build Coastguard Workerimport subprocess
19*da0073e9SAndroid Build Coastguard Workerimport textwrap
20*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Mapping, Sequence
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workerimport yaml
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerfrom torchgen import utils as torchgen_utils
25*da0073e9SAndroid Build Coastguard Workerfrom torchgen.yaml_utils import YamlLoader
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
# Header text passed as "generated_comment" to both the Python and the C++
# templates (the C++ generator prefixes every line with " * " first).
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.

Diagnostic rules for PyTorch ONNX export.
"""

# Text passed as "generated_rule_class_comment" to the Python template.
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""

# Source of one generated rule class; filled in by
# `_format_rule_for_python_class`. Placeholders: pascal_case_name,
# short_description, message_template, message_arguments (parameter list) and
# message_arguments_assigned (the same names forwarded as keyword arguments).
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
    \"\"\"{short_description}\"\"\"
    def format_message(  # type: ignore[override]
        self,
        {message_arguments}
    ) -> str:
        \"\"\"Returns the formatted default message of this Rule.

        Message template: {message_template}
        \"\"\"
        return self.message_default_template.format({message_arguments_assigned})

    def format(  # type: ignore[override]
        self,
        level: infra.Level,
        {message_arguments}
    ) -> Tuple[infra.Rule, infra.Level, str]:
        \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

        Message template: {message_template}
        \"\"\"
        return self, level, self.format_message({message_arguments_assigned})

"""

# Source of one dataclass field of the generated rule collection; filled in by
# `_format_rule_for_python_field`. `sarif_dict` receives the rule mapping
# itself, so its string form is what ends up in the generated file.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
    default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
    init=False,
)
\"\"\"{short_description}\"\"\"
"""

# Source of one documented C++ entry (name followed by a comma); filled in by
# `_format_rule_for_cpp`.
_CPP_RULE_TEMPLATE = """\
/**
 * @brief {short_description}
 */
{name},
"""

# One rule as loaded from the rules YAML file. The formatters in this module
# read rule["name"], rule["short_description"]["text"] and
# rule["message_strings"]["default"]["text"].
_RuleType = Mapping[str, Any]
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Workerdef _kebab_case_to_snake_case(name: str) -> str:
87*da0073e9SAndroid Build Coastguard Worker    return name.replace("-", "_")
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Workerdef _kebab_case_to_pascal_case(name: str) -> str:
91*da0073e9SAndroid Build Coastguard Worker    return "".join(word.capitalize() for word in name.split("-"))
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker
def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Render the source of the generated Python rule class for one rule.

    The keyword replacement fields found in the rule's default message template
    become the parameters of the generated ``format_message`` and ``format``
    methods, and are forwarded to ``str.format`` as keyword arguments.

    Args:
        rule: Rule mapping. Reads ``rule["name"]``,
            ``rule["short_description"]["text"]`` and
            ``rule["message_strings"]["default"]["text"]``.

    Returns:
        ``_PY_RULE_CLASS_TEMPLATE`` filled in for this rule.

    Raises:
        AssertionError: If the message template contains an auto-numbered
            (``{}``) or numeric (``{0}``) replacement field; only keyword
            fields can be turned into parameters.
    """
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]
    # A field may be referenced several times in one template. Keep only the
    # first occurrence (in order) so the generated signature does not declare
    # the same parameter twice, which would be a SyntaxError.
    field_names = list(
        dict.fromkeys(
            field_name
            for _, field_name, _, _ in string.Formatter().parse(message_template)
            if field_name is not None
        )
    )
    for field_name in field_names:
        # The message must be parenthesized: a string literal placed on the
        # line after the assert is a separate no-op statement and would never
        # be part of the error message.
        assert isinstance(field_name, str), (
            f"Unexpected field type {type(field_name)} from {field_name}. "
            "Field name must be string.\n"
            f"Full message template: {message_template}"
        )
        # An empty name comes from an auto-numbered "{}" field, which is
        # positional just like "{0}" and cannot become a keyword parameter.
        assert field_name and not field_name.isnumeric(), (
            f"Unexpected positional or numeric field name {field_name!r}. "
            "Only keyword name formatting is supported.\n"
            f"Full message template: {message_template}"
        )
    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        f"{field_name}={field_name}" for field_name in field_names
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=pascal_case_name,
        short_description=short_description,
        # repr() keeps the template on one line inside the generated docstring.
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker
def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the generated rule-collection field for one rule.

    Args:
        rule: Rule mapping. Reads ``rule["name"]`` and
            ``rule["short_description"]["text"]``; the mapping itself is
            substituted for the ``sarif_dict`` placeholder.

    Returns:
        ``_PY_RULE_COLLECTION_FIELD_TEMPLATE`` filled in for this rule.
    """
    rule_name = rule["name"]
    substitutions = {
        "snake_case_name": _kebab_case_to_snake_case(rule_name),
        "pascal_case_name": _kebab_case_to_pascal_case(rule_name),
        "sarif_dict": rule,
        "short_description": rule["short_description"]["text"],
    }
    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(**substitutions)
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker
def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the documented C++ entry for one rule.

    The C++ identifier is the PascalCase rule name prefixed with ``k``.

    Args:
        rule: Rule mapping. Reads ``rule["name"]`` and
            ``rule["short_description"]["text"]``.

    Returns:
        ``_CPP_RULE_TEMPLATE`` filled in for this rule.
    """
    cpp_identifier = "k" + _kebab_case_to_pascal_case(rule["name"])
    brief = rule["short_description"]["text"]
    return _CPP_RULE_TEMPLATE.format(short_description=brief, name=cpp_identifier)
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker
def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Generate ``_rules.py`` in ``out_py_dir`` and lint it.

    Args:
        rules: Rules to render, in output order.
        out_py_dir: Directory that receives ``_rules.py``.
        template_dir: Directory containing the ``rules.py.in`` template.
    """
    output_name = "_rules.py"
    class_sources = "\n".join(map(_format_rule_for_python_class, rules))
    field_sources = "\n".join(map(_format_rule_for_python_field, rules))

    file_manager = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(
        output_name,
        "rules.py.in",
        lambda: {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": class_sources,
            # Fields are indented by four spaces before substitution.
            "rules": textwrap.indent(field_sources, "    "),
        },
    )
    _lint_file(os.path.join(out_py_dir, output_name))
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker
def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Generate ``rules.h`` in ``out_cpp_dir`` and lint it.

    Args:
        rules: Rules to render, in output order.
        out_cpp_dir: Directory that receives ``rules.h``.
        template_dir: Directory containing the ``rules.h.in`` template.
    """
    output_name = "rules.h"
    entry_sources = "\n".join(map(_format_rule_for_cpp, rules))
    # One quoted snake_case rule name per line, each followed by a comma.
    quoted_names = "\n".join(
        f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules
    )

    file_manager = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(
        output_name,
        "rules.h.in",
        lambda: {
            "generated_comment": textwrap.indent(
                _RULES_GENERATED_COMMENT,
                " * ",
                # Prefix every line, including blank ones, which
                # textwrap.indent would otherwise leave untouched.
                predicate=lambda _line: True,
            ),
            "rules": textwrap.indent(entry_sources, "  "),
            "py_rule_names": textwrap.indent(quoted_names, "    "),
        },
    )
    _lint_file(os.path.join(out_cpp_dir, output_name))
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker
def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing.

    All arguments are accepted and ignored.
    """
    # TODO: Add doc generation in a follow-up PR.
    return None
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker
def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on ``file_path`` and wait for it to finish.

    The exit status of the linter is not inspected.
    """
    command = ["lintrunner", "-a", file_path]
    with subprocess.Popen(command) as linter:
        linter.wait()
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker
def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Load the rules file and run the Python, C++ and docs generators.

    Templates are read from the ``templates`` directory located next to this
    script.

    Args:
        rules_path: Path to the YAML file that defines the rules.
        out_py_dir: Output directory for the generated Python file.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory passed to the docs generator.
    """
    # Read the YAML explicitly as UTF-8 so the result does not depend on the
    # locale's default encoding of the machine running the generator.
    # NOTE(review): yaml.load is only safe if torchgen's YamlLoader is a safe
    # loader -- confirm before pointing this at untrusted input.
    with open(rules_path, encoding="utf-8") as f:
        rules = yaml.load(f, Loader=YamlLoader)

    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

    gen_diagnostics_python(
        rules,
        out_py_dir,
        template_dir,
    )

    gen_diagnostics_cpp(
        rules,
        out_cpp_dir,
        template_dir,
    )

    gen_diagnostics_docs(rules, out_docs_dir, template_dir)
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker
def main() -> None:
    """Parse the four positional command-line paths and run the generators."""
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")
    # (destination, metavar, help) for each positional argument, in CLI order.
    positional_arguments = (
        ("rules_path", "RULES", "path to rules.yaml"),
        ("out_py_dir", "OUT_PY", "path to output directory for Python"),
        ("out_cpp_dir", "OUT_CPP", "path to output directory for C++"),
        ("out_docs_dir", "OUT_DOCS", "path to output directory for docs"),
    )
    for destination, metavar, help_text in positional_arguments:
        parser.add_argument(destination, metavar=metavar, help=help_text)

    options = parser.parse_args()
    gen_diagnostics(
        rules_path=options.rules_path,
        out_py_dir=options.out_py_dir,
        out_cpp_dir=options.out_cpp_dir,
        out_docs_dir=options.out_docs_dir,
    )


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
258