xref: /aosp_15_r20/external/executorch/codegen/tools/gen_oplist.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import json
9import os
10import sys
11from enum import IntEnum
12from typing import Any, Dict, List, Optional, Set
13
14import yaml
15from torchgen.executorch.parse import strip_et_fields
16
17from torchgen.gen import LineLoader, parse_native_yaml_struct
18from torchgen.selective_build.operator import SelectiveBuildOperator
19from torchgen.selective_build.selector import merge_et_kernel_metadata
20
21# Output YAML file format:
22# ------------------------
23#
24# <BEGIN FILE CONTENTS>
25# include_all_non_op_selectives: False
26# include_all_operators: False
27# debug_info:
28#   - model1@v100
29#   - model2@v50
30# operators:
31#   aten::add:
32#     is_root_operator: Yes
33#     is_used_for_training: Yes
34#     include_all_overloads: No
35#     debug_info:
36#       - model1@v100
37#       - model2@v50
38#   aten::add.int:
39#     is_root_operator: No
40#     is_used_for_training: No
41#     include_all_overloads: Yes
42# et_kernel_metadata:
43#   aten::add.out:
44#     # A list of different kernel keys (tensors with dtype-enum/dim-order) combinations used in model
45#       - v1/6;0,1|6;0,1|6;0,1|6;0,1  # Float, 0, 1
46#       - v1/3;0,1|3;0,1|3;0,1|3;0,1  # Int, 0, 1
47#   aten::mul.out:
48#       - v1/6;0,1|6;0,1|6;0,1|6;0,1  # Float, 0, 1
49# <END FILE CONTENTS>
50
51
class ScalarType(IntEnum):
    """Scalar-type (dtype) codes recognized when emitting kernel keys.

    NOTE(review): the numeric values appear to mirror a serialized dtype
    enum (the gaps at 5, 8-10, 15 match the "not implemented" list below),
    so they must not be renumbered — confirm against the ExecuTorch/ATen
    ScalarType definition before changing anything here.
    """

    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11
    # TODO(jakeszwe): Verify these are unused and then remove support
    QInt8 = 12
    QUInt8 = 13
    QInt32 = 14
    QUInt4X2 = 16
    QUInt2X4 = 17
    # Types currently not implemented.
    # Half = 5
    # ComplexHalf = 8
    # ComplexFloat = 9
    # ComplexDouble = 10
    # BFloat16 = 15
73
74
class KernelType(IntEnum):
    """Kernel argument-category codes.

    These are compared against ``_IOMetaData.kernel_type`` in
    ``_get_kernel_metadata_for_model``; only arguments of these tensor-like
    kinds contribute to a kernel key.  NOTE(review): the values presumably
    mirror a serialized type-tag enum — verify before renumbering.
    """

    TENSOR = 5
    TENSOR_LIST = 10
    OPTIONAL_TENSOR_LIST = 11
79
80
def _get_operators(model_file: str) -> List[str]:
    """Deserialize the ExecuTorch program in *model_file* and return the
    names of all operators it references."""
    from executorch.codegen.tools.selective_build import (
        _get_program_from_buffer,
        _get_program_operators,
    )

    print("Processing model file: ", model_file)
    with open(model_file, "rb") as model:
        contents = model.read()

    ops = _get_program_operators(_get_program_from_buffer(contents))
    print(f"Model file loaded, operators are: {ops}")
    return ops
95
96
def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
    """Build a mapping from operator name to the kernel keys used in a model.

    A kernel key encodes a version prefix plus, for every tensor-like
    argument, its dtype code and dim order, e.g. ``v1/6;0,1|6;0,1``.

    Args:
        model_file: Path to a serialized ExecuTorch program.

    Returns:
        Dict mapping operator name (e.g. ``aten::add.out``) to the list of
        kernel-key strings observed for it.
    """

    from executorch.codegen.tools.selective_build import (
        _get_io_metadata_for_program_operators,
        _get_program_from_buffer,
        _IOMetaData,  # noqa: F401  -- kept for parity with the original import
    )

    with open(model_file, "rb") as f:
        buf = f.read()

    program = _get_program_from_buffer(buf)
    operators_with_io_metadata = _get_io_metadata_for_program_operators(program)

    op_kernel_key_list: Dict[str, List[str]] = {}

    # NOTE(review): the original pre-declared `specialized_kernels` as
    # Set[List[_IOMetaData]], which is an impossible type (lists are
    # unhashable and cannot be set members).  The bogus annotation is
    # dropped; each value is simply an iterable of io-metadata sequences,
    # one per (dtype, dim-order) combination seen in the model.
    for op_name, specialized_kernels in operators_with_io_metadata.items():
        print(op_name)
        kernel_keys = op_kernel_key_list.setdefault(op_name, [])

        for specialized_kernel in specialized_kernels:
            version = "v1"
            kernel_key = version + "/"
            for io_metadata in specialized_kernel:
                # Only tensor-like arguments contribute to the kernel key.
                if io_metadata.kernel_type in [
                    KernelType.TENSOR,
                    KernelType.TENSOR_LIST,
                    KernelType.OPTIONAL_TENSOR_LIST,
                ]:
                    dim_order = ",".join(map(str, io_metadata.dim_order))
                    kernel_key += f"{io_metadata.dtype};{dim_order}|"
            # Strip the trailing "|" separator.  (If an operator had no
            # tensor-like args this strips the "/" instead, yielding "v1" --
            # preserved from the original behavior.)
            kernel_keys.append(kernel_key[:-1])

    return op_kernel_key_list
133
134
def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[str]]:
    """Parse an ops yaml file and map every operator it declares to the
    ``"default"`` kernel key."""
    op_names: List[str] = []
    with open(ops_yaml_path, "r") as yaml_file:
        entries = yaml.load(yaml_file, Loader=LineLoader)
        native_entries = []
        for entry in entries:
            if "op" in entry:
                name = entry.get("op")
                # Bare names without an explicit namespace default to "aten::".
                op_names.append(name if "::" in name else "aten::" + name)
            else:
                native_entries.append(entry)
        strip_et_fields(entries)
        parsed_yaml = parse_native_yaml_struct(
            native_entries, set(), None, path=ops_yaml_path, skip_native_fns_gen=True
        )
    op_names.extend(
        f"{fn.namespace}::{fn.func.name}" for fn in parsed_yaml.native_functions
    )
    # TODO (larryliu): accept the new op yaml syntax
    return {name: ["default"] for name in op_names}
152
153
def _dump_yaml(
    op_list: List[str],
    output_path: str,
    model_name: Optional[str] = None,
    et_kernel_metadata: Optional[Dict[str, List[str]]] = None,
    include_all_operators: bool = False,
):
    """Write the selected-operators yaml file consumed by torchgen's gen.py."""
    # no debug info yet
    operators: Dict[str, Dict[str, object]] = {
        op_name: SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": True,
                "is_used_for_training": True,
                "include_all_overloads": False,
                "debug_info": [model_name],
            },
        ).to_dict()
        for op_name in op_list
    }

    output = {
        "operators": operators,
        "custom_classes": [],
        "build_features": [],
        "include_all_non_op_selectives": False,
        "include_all_operators": include_all_operators,
        "kernel_metadata": {},
        "et_kernel_metadata": et_kernel_metadata,
    }
    with open(output_path, "wb") as out_file:
        out_file.write(
            yaml.safe_dump(output, default_flow_style=False).encode("utf-8")
        )
190
191
def gen_oplist(
    output_path: str,
    model_file_path: Optional[str] = None,
    ops_schema_yaml_path: Optional[str] = None,
    root_ops: Optional[str] = None,
    ops_dict: Optional[str] = None,
    include_all_operators: bool = False,
):
    """Collect operators from every provided source, merge their kernel
    metadata, and dump the result to *output_path* as yaml.

    Sources may be combined; at least one of the optional inputs (or
    ``include_all_operators``) must be supplied.
    """
    assert (
        model_file_path
        or ops_schema_yaml_path
        or root_ops
        or ops_dict
        or include_all_operators
    ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."

    assert output_path, "Need to provide output_path for dumped yaml file."
    ops: Set[str] = set()
    source_name = None
    et_kernel_metadata: Dict[str, List[str]] = {}

    if root_ops:
        # decide delimiter
        delimiter = "," if "," in root_ops else " "
        print(root_ops)
        stripped = (piece.strip() for piece in root_ops.split(delimiter))
        ops.update(piece for piece in stripped if piece)
        # Root-op names get the "default" kernel key.
        et_kernel_metadata = merge_et_kernel_metadata(
            et_kernel_metadata, {op: ["default"] for op in ops}
        )

    if ops_dict:
        for op, metadata in json.loads(ops_dict).items():
            ops.add(op)
            et_kernel_metadata = merge_et_kernel_metadata(
                et_kernel_metadata,
                {op: metadata if metadata else ["default"]},
            )

    if model_file_path:
        assert os.path.isfile(
            model_file_path
        ), f"The value for --model_file_path needs to be a valid file, got {model_file_path}"
        ops.update(_get_operators(model_file_path))
        source_name = model_file_path
        et_kernel_metadata = merge_et_kernel_metadata(
            et_kernel_metadata, _get_kernel_metadata_for_model(model_file_path)
        )

    if ops_schema_yaml_path:
        assert os.path.isfile(
            ops_schema_yaml_path
        ), f"The value for --ops_schema_yaml_path needs to be a valid file, got {ops_schema_yaml_path}"
        et_kernel_metadata = merge_et_kernel_metadata(
            et_kernel_metadata,
            _get_et_kernel_metadata_from_ops_yaml(ops_schema_yaml_path),
        )
        # Every operator with metadata (from any source) ends up in the set.
        ops.update(et_kernel_metadata.keys())
        source_name = ops_schema_yaml_path

    _dump_yaml(
        sorted(ops),
        output_path,
        os.path.basename(source_name) if source_name else None,
        et_kernel_metadata,
        include_all_operators,
    )
256
257
def main(args: List[Any]) -> None:
    """This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
    It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
    it can be used in gen.py.
    """
    parser = argparse.ArgumentParser(
        description="Generate operator list from a model file"
    )
    parser.add_argument(
        "--output_path",
        help=("The path to the output yaml file (selected_operators.yaml)"),
        required=True,
    )
    # The optional string-valued inputs all share the same shape.
    for flag, help_text in (
        ("--model_file_path", "Path to an executorch program"),
        (
            "--ops_schema_yaml_path",
            "Dump operator names from operator schema yaml path",
        ),
        (
            "--root_ops",
            "A comma separated list of root operators used by the model",
        ),
        (
            "--ops_dict",
            "A json object containing operators and their associated dtype and dim order",
        ),
    ):
        parser.add_argument(flag, help=help_text, required=False)
    parser.add_argument(
        "--include-all-operators",
        "--include_all_operators",
        action="store_true",
        default=False,
        help="Set this flag to request inclusion of all operators (i.e. build is not selective).",
        required=False,
    )
    options = parser.parse_args(args)

    try:
        gen_oplist(
            output_path=options.output_path,
            model_file_path=options.model_file_path,
            ops_schema_yaml_path=options.ops_schema_yaml_path,
            root_ops=options.root_ops,
            ops_dict=options.ops_dict,
            include_all_operators=options.include_all_operators,
        )
    except Exception as e:
        # Reconstruct an equivalent CLI invocation so failures are easy to repro.
        command = ["python codegen/tools/gen_oplist.py"]
        for flag, value in (
            ("--model_file_path", options.model_file_path),
            ("--ops_schema_yaml_path", options.ops_schema_yaml_path),
            ("--root_ops", options.root_ops),
            ("--ops_dict", options.ops_dict),
        ):
            if value:
                command.append(f"{flag} {value}")
        if options.include_all_operators:
            command.append("--include-all-operators")
        repro_command = " ".join(command)
        raise RuntimeError(
            f"""Failed to generate selected_operators.yaml. Repro command:
            {repro_command}
            """
        ) from e
330
331
# Script entry point: forward the CLI arguments (minus the program name).
if __name__ == "__main__":
    main(sys.argv[1:])
334