xref: /aosp_15_r20/external/pytorch/tools/code_analyzer/gen_oplist.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3from __future__ import annotations
4
5import argparse
6import json
7import os
8import sys
9from functools import reduce
10from typing import Any
11
12import yaml
13from tools.lite_interpreter.gen_selected_mobile_ops_header import (
14    write_selected_mobile_ops,
15)
16
17from torchgen.selective_build.selector import (
18    combine_selective_builders,
19    SelectiveBuilder,
20)
21
22
def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return every operator name registered in ``selective_builder``.

    The result is built from the keys of ``selective_builder.operators``.
    """
    operator_names: set[str] = set()
    # Iterating a mapping yields its keys, so this collects every operator name.
    operator_names.update(selective_builder.operators)
    return operator_names
25
26
def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of operators flagged ``is_used_for_training``.

    Each entry of ``selective_builder.operators`` is checked. The name is kept
    when the operator's ``is_used_for_training`` attribute is truthy.
    """
    return {
        name
        for name, operator in selective_builder.operators.items()
        if operator.is_used_for_training
    }
33
34
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
    """Reject operators that were registered with ``include_all_overloads`` set.

    ``main`` calls this only when ``--allow-include-all-overloads`` was *not*
    passed on the command line. The error message therefore says the flag was
    not specified.

    Args:
        selective_builder: Builder whose ``operators`` mapping is inspected.

    Raises:
        Exception: If at least one operator has ``include_all_overloads`` set.
            The message lists the offending operator names, comma-separated, in
            the iteration order of ``selective_builder.operators``.
    """
    offending_ops = [
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.include_all_overloads
    ]
    if offending_ops:
        raise Exception(  # noqa: TRY002
            (
                "Operators that include all overloads are "
                + "not allowed since --allow-include-all-overloads "
                + "was not specified: {}"
            ).format(", ".join(offending_ops))
        )
48
49
def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
    """Write ``SupportedMobileModelsRegistration.cpp`` into ``output_dir``.

    The generated C++ defines a global registry object whose construction
    passes a set of MD5 hashes to
    ``SupportedMobileModelChecker::singleton().set_supported_md5_hashes``.

    Hashes are collected only from model dicts that contain a ``debug_info``
    entry. The first element of that entry is parsed as JSON, and its
    ``is_new_style_rule`` field must be truthy. For those, every
    ``asset_info[*]["md5_hash"]`` is collected. All other model dicts
    contribute nothing.

    Args:
        model_dicts: Parsed model YAML dictionaries.
        output_dir: Existing directory the ``.cpp`` file is written to.
    """
    supported_mobile_models_source = """/*
 * Generated by gen_oplist.py
 */
#include "fb/supported_mobile_models/SupportedMobileModels.h"


struct SupportedMobileModelCheckerRegistry {{
  SupportedMobileModelCheckerRegistry() {{
    auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton();
    ref.set_supported_md5_hashes(std::unordered_set<std::string>{{
      {supported_hashes_template}
    }});
  }}
}};

// This is a global object, initializing which causes the registration to happen.
SupportedMobileModelCheckerRegistry register_model_versions;


"""

    md5_hashes: set[str] = set()
    for model_dict in model_dicts:
        if "debug_info" not in model_dict:
            continue
        debug_info = json.loads(model_dict["debug_info"][0])
        if not debug_info["is_new_style_rule"]:
            continue
        for asset_info in debug_info["asset_info"].values():
            # NOTE(review): set.update() iterates its argument, so "md5_hash" is
            # expected to be a collection of hash strings; a bare string would
            # be split into single characters. Confirm against the producer.
            md5_hashes.update(asset_info["md5_hash"])

    # Emit hashes in sorted order. Iteration order of a set of str depends on
    # per-process hash randomization, so an unsorted emit would make the
    # generated source differ between otherwise identical builds.
    supported_hashes = "".join(f'"{md5}",\n' for md5 in sorted(md5_hashes))

    # Render before opening so a formatting failure cannot leave a truncated file.
    source = supported_mobile_models_source.format(
        supported_hashes_template=supported_hashes
    )
    output_path = os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp")
    with open(output_path, "wb") as out_file:
        out_file.write(source.encode("utf-8"))
91
92
def main(argv: list[Any]) -> None:
    """Generate selective-build artifacts from one or more model YAML files.

    This binary writes 3 files into ``--output-dir``:

    1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
       dtypes captured by tracing
    2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
    3. SupportedMobileModelsRegistration.cpp: registers the MD5 hashes of supported
       mobile models (written by ``gen_supported_mobile_models``)

    Args:
        argv: Command-line arguments, excluding the program name.

    Raises:
        SystemExit: Via ``argparse`` when the arguments are invalid.
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help=(
            "The directory to store the output files (selected_mobile_ops.h, "
            + "selected_operators.yaml, SupportedMobileModelsRegistration.cpp)"
        ),
        required=True,
    )
    parser.add_argument(
        "--model-file-list-path",
        "--model_file_list_path",
        help=(
            "Path to a file that contains the locations of individual "
            + "model YAML files that contain the set of used operators. This "
            + "file path must have a leading @-symbol, which will be stripped "
            + "out before processing. Alternatively, the path of a single "
            + "existing model YAML file may be given without the @-symbol."
        ),
        required=True,
    )
    parser.add_argument(
        "--allow-include-all-overloads",
        "--allow_include_all_overloads",
        help=(
            "Flag to allow operators that include all overloads. "
            + "If not set, operators registered without using the traced style will "
            + "break the build."
        ),
        action="store_true",
        default=False,
        required=False,
    )
    options = parser.parse_args(argv)

    model_file_list_path = options.model_file_list_path
    model_dicts: list[Any] = []
    if os.path.isfile(model_file_list_path):
        # The argument names a single model YAML file directly.
        print("Processing model file: ", model_file_list_path)
        # Binary mode, like the list branch below, so PyYAML detects the
        # encoding itself instead of using the locale default.
        with open(model_file_list_path, "rb") as model_file:
            model_dicts.append(yaml.safe_load(model_file))
    else:
        # The argument is "@<path>", where <path> lists model YAML files
        # separated by whitespace.
        print("Processing model directory: ", model_file_list_path)
        if not model_file_list_path.startswith("@"):
            # Explicit check instead of `assert`, which is stripped under -O.
            parser.error(
                "--model-file-list-path must be an existing model YAML file or "
                f"an @-prefixed list file, got {model_file_list_path!r}"
            )
        with open(model_file_list_path[1:]) as model_list_file:
            model_file_names = model_list_file.read().split()
        for model_file_name in model_file_names:
            with open(model_file_name, "rb") as model_file:
                model_dicts.append(yaml.safe_load(model_file))

    selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

    # While we have the model_dicts generate the supported mobile models api
    gen_supported_mobile_models(model_dicts, options.output_dir)

    # We may have 0 selective builders since there may not be any viable
    # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
    # This is potentially an error, and we should probably raise an assertion
    # failure here. However, this needs to be investigated further.
    selective_builder = SelectiveBuilder.from_yaml_dict({})
    if selective_builders:
        selective_builder = reduce(
            combine_selective_builders,
            selective_builders,
        )

    if not options.allow_include_all_overloads:
        throw_if_any_op_includes_overloads(selective_builder)
    with open(
        os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
    ) as out_file:
        out_file.write(
            yaml.safe_dump(
                selective_builder.to_dict(), default_flow_style=False
            ).encode("utf-8"),
        )

    write_selected_mobile_ops(
        os.path.join(options.output_dir, "selected_mobile_ops.h"),
        selective_builder,
    )
184
185
if __name__ == "__main__":
    # Drop the program name; argparse in main() only wants the user arguments.
    cli_args = sys.argv[1:]
    main(cli_args)
188