xref: /aosp_15_r20/external/executorch/codegen/tools/gen_selected_op_variants.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# pyre-unsafe
9
10import argparse
11import os
12import sys
13from typing import Any, List
14
15import yaml
16
17from torchgen.code_template import CodeTemplate
18
19
# C++ boolean expression for a single operator: true when the requested
# operator name matches $operator_name and the scalar type passes
# $dtype_checks. The per-operator expressions are OR-ed together into the
# body of should_include_kernel_dtype().
ops_and_dtypes_template_str = (
    '((exec_aten::string_view(operator_name).compare("$operator_name") == 0)\n'
    "        && ($dtype_checks))"
)
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)
22
# Full contents of the generated header. $body is replaced with the OR of the
# per-operator expressions (or the literal "true" when nothing was selected).
# The header is assembled line by line; the trailing "" entry makes the
# text end with a newline after the closing brace.
selected_kernel_dtypes_h_template_str = "\n".join(
    [
        "#pragma once",
        "/**",
        " * Generated by executorch/codegen/tools/gen_selected_op_variants.py",
        " */",
        "",
        "inline constexpr bool should_include_kernel_dtype(",
        "  const char *operator_name,",
        "  exec_aten::ScalarType scalar_type",
        ") {",
        "  return $body;",
        "}",
        "",
    ]
)
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
36
# Maps the integer ScalarType enum value, written as a decimal string (the form
# it takes inside et_kernel_metadata entries), to the enumerator's name.
# enum from: https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
# Each name's position in the list below IS its enum value, so only append new
# entries at the end and never reorder.
dtype_enum_to_type = {
    str(enum_value): type_name
    for enum_value, type_name in enumerate(
        [
            "Byte",  # 0
            "Char",  # 1
            "Short",  # 2
            "Int",  # 3
            "Long",  # 4
            "Half",  # 5
            "Float",  # 6
            "Double",  # 7
            "ComplexHalf",  # 8
            "ComplexFloat",  # 9
            "ComplexDouble",  # 10
            "Bool",  # 11
            "QInt8",  # 12
            "QUInt8",  # 13
            "QInt32",  # 14
            "BFloat16",  # 15
            "QUInt4x2",  # 16
            "QUInt2x4",  # 17
            "Bits1x8",  # 18
            "Bits2x4",  # 19
            "Bits4x2",  # 20
            "Bits8",  # 21
            "Bits16",  # 22
        ]
    )
}
63
64
def _dtype_conditions(kernel_metadata_list: List[str]) -> List[str]:
    """Return the C++ scalar-type checks for one operator's kernel metadata.

    Each entry is expected to look like ``v1/6;0,1|6;0,1`` (a version, then
    ``|``-separated per-tensor groups). Only the dtype enum before the first
    ``;`` of each group is used; the remaining fields are ignored here.

    Returns ``["true"]`` (every dtype allowed) if any entry is ``"default"`` or
    has no ``/`` separator, or if no dtype could be extracted at all.
    Otherwise returns one ``scalar_type == exec_aten::ScalarType::<Name>``
    check per distinct dtype, sorted by name.

    Raises:
        KeyError: if a dtype enum value is missing from ``dtype_enum_to_type``.
    """
    dtype_enums = set()
    for kernel_metadata in kernel_metadata_list:
        if kernel_metadata == "default" or "/" not in kernel_metadata:
            # A dtype-agnostic kernel is selected, so every dtype must stay
            # available. Dtypes collected from earlier entries must NOT
            # narrow this; previously the result depended on entry order.
            return ["true"]
        tensor_meta = kernel_metadata.split("/")[1]
        for tensor_group in tensor_meta.split("|"):
            dtype_enum = tensor_group.split(";")[0]
            # Skip empty groups (e.g. "v1/"), which would otherwise raise
            # KeyError('') on the enum lookup below.
            if dtype_enum:
                dtype_enums.add(dtype_enum)
    if not dtype_enums:
        return ["true"]
    dtype_names = sorted(dtype_enum_to_type[enum] for enum in dtype_enums)
    return ["scalar_type == exec_aten::ScalarType::" + name for name in dtype_names]


def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
    """Generate ``selected_op_variants.h`` from a selected_operators.yaml file.

    Reads the ``et_kernel_metadata`` mapping (operator name -> list of kernel
    metadata strings, e.g. ``v1/6;0,1|6;0,1|6;0,1|6;0,1`` where 6 is Float).
    It then writes a header to ``output_dir`` defining
    ``should_include_kernel_dtype(operator_name, scalar_type)``. The function
    returns true only for the listed operators (with an ``aten::`` prefix
    stripped) and their listed dtypes. When no metadata is present, the body
    is the literal ``true``.

    Args:
        yaml_file_path: Path to the selected_operators.yaml file to read.
        output_dir: Directory in which ``selected_op_variants.h`` is written.

    Raises:
        TypeError: if ``et_kernel_metadata`` is present but not a mapping.
        KeyError: if a dtype enum value is unknown.
        OSError: if the input cannot be read or the output cannot be written.
    """
    with open(yaml_file_path, "r", encoding="utf-8") as selected_operators_file:
        # yaml.safe_load returns None for an empty document.
        selected_operators_dict = yaml.safe_load(selected_operators_file) or {}

    et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata") or {}
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(et_kernel_metadata, dict):
        raise TypeError(
            "et_kernel_metadata in %s must be a mapping, got %s"
            % (yaml_file_path, type(et_kernel_metadata).__name__)
        )

    body_parts = [
        ops_and_dtypes_template.substitute(
            operator_name=operator_name.replace("aten::", ""),
            dtype_checks=" || ".join(_dtype_conditions(kernel_metadata_list or [])),
        )
        for operator_name, kernel_metadata_list in et_kernel_metadata.items()
    ]
    # Join once (previously re-joined on every loop iteration). With no
    # selected operators, fall back to including everything.
    body = "\n || ".join(body_parts) if body_parts else "true"

    header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
    selected_op_variants_path = os.path.join(output_dir, "selected_op_variants.h")
    with open(selected_op_variants_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))
100
101
def main(argv: List[Any]) -> None:
    """Parse command-line arguments and generate selected_op_variants.h.

    Args:
        argv: Command-line arguments without the program name
            (for example ``sys.argv[1:]``).

    Raises:
        SystemExit: raised by argparse when a required argument is missing or
            invalid, or when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate selected_op_variants.h from the et_kernel_metadata "
            "section of selected_operators.yaml"
        )
    )
    parser.add_argument(
        "--yaml-file-path",
        "--yaml_file_path",
        # This is a path to the YAML file itself, not a directory.
        help="Path to the selected_operators.yaml file to read.",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        # Only selected_op_variants.h is written by this script.
        help="The directory to write the generated selected_op_variants.h header to.",
        required=True,
    )

    options = parser.parse_args(argv)
    write_selected_op_variants(options.yaml_file_path, options.output_dir)
122
123
if __name__ == "__main__":
    # Drop the program name; argparse only needs the actual arguments.
    main(argv=sys.argv[1:])
125    main(sys.argv[1:])
126