xref: /aosp_15_r20/external/executorch/codegen/tools/merge_yaml.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import os
9import sys
10from collections import defaultdict
11from typing import Any, Dict, List, Optional
12
13import yaml
14
15try:
16    from yaml import CSafeLoader as Loader
17except ImportError:
18    from yaml import SafeLoader as Loader  # type: ignore[misc]
19
20
class BlankLineDumper(yaml.SafeDumper):
    """``yaml.SafeDumper`` subclass that puts an empty line between entries."""

    def write_line_break(self, data=None):
        """Write the regular line break, then a second one when the indent
        stack (``self.indents``) holds exactly one level.

        Args:
            data: forwarded unchanged to ``SafeDumper.write_line_break``.
        """
        super().write_line_break(data)
        if len(self.indents) != 1:
            return
        # The extra break yields the blank line separating entries.
        super().write_line_break()
27
28
def merge(
    functions_yaml_path: str, fallback_yaml_path: Optional[str], output_dir: str
) -> None:
    """Merge a functions yaml with an optional fallback yaml into ``merged.yaml``.

    Entries from ``functions_yaml_path`` take precedence: a fallback entry is
    kept only when no entry with the same canonical operator name has already
    been loaded. If the same operator appears more than once inside
    ``functions_yaml_path``, the last occurrence wins. The merged list is
    written to ``<output_dir>/merged.yaml``.

    Args:
        functions_yaml_path: path to the primary yaml file.
        fallback_yaml_path: path to the fallback yaml file. Skipped when it is
            ``None`` or does not exist on disk.
        output_dir: directory that receives ``merged.yaml``.

    Raises:
        KeyError: if an entry has neither an ``op`` nor a ``func`` key.
    """
    output_file = os.path.join(output_dir, "merged.yaml")

    def get_canonical_opname(func: object) -> str:
        """get the canonical name of an operator
        "op" and "func" are two keywords we are supporting for yaml files.
        To give an example:
        - op: add.Tensor # mostly used for binding ATen ops to kernels
        - func: add.Tensor(Tensor self, Tensor other, Scalar alpha) # mostly used for
            defining custom ops.

        These two will be supported
        Args:
            func (object): yaml object

        Returns:
            str: canonical name of the operator, prefixed with "aten::" when
                the entry does not carry a namespace of its own.
        """
        # pyre-ignore
        opname = func["op"] if "op" in func else func["func"].split("(")[0]
        if "::" not in opname:
            opname = "aten::" + opname
        return opname

    def load_entries(yaml_path: str) -> List[object]:
        """Load the entries of a yaml file; an empty document yields ``[]``."""
        with open(yaml_path) as f:
            entries = yaml.load(f, Loader=Loader)
        # yaml.load returns None for an empty (or comment-only) file; treat
        # that as "no entries" instead of failing on iteration below.
        return [] if entries is None else entries

    # A plain dict is enough: keys are only assigned or tested with `in`, and
    # insertion order determines the order of the dumped entries.
    functions_dict: Dict[str, object] = {}
    for func in load_entries(functions_yaml_path):
        functions_dict[get_canonical_opname(func)] = func

    if fallback_yaml_path is not None and os.path.exists(fallback_yaml_path):
        for func in load_entries(fallback_yaml_path):
            opname = get_canonical_opname(func)
            # Fallback entries never override an entry from functions.yaml.
            if opname not in functions_dict:
                functions_dict[opname] = func

    with open(output_file, "w") as f:
        yaml.dump(
            list(functions_dict.values()),
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )
75
76
def main(argv: List[Any]) -> None:
    """Merge functions.yaml and fallback yaml. The output yaml will be a union of all entries in functions.yaml and fallback yaml, with operator entries in functions.yaml overriding entries with the same op name in fallback yaml.

    E.g. (schematic entries),
    functions.yaml:
    - op: add.Tensor
      kernel: add_impl

    fallback yaml:
    - op: add.Tensor
      kernel: add_fallback
    - op: relu
      kernel: relu_fallback

    Merged:
    - op: add.Tensor
      kernel: add_impl
    - op: relu
      kernel: relu_fallback

    Args:
        argv: command-line arguments, without the program name.

    Raises:
        SystemExit: with status 2 when the arguments are invalid or the
            functions yaml file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Merge functions.yaml, custom_ops.yaml with fallback yaml, for codegen to consume."
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use.",
        required=True,
    )
    parser.add_argument(
        "--fallback-yaml-path",
        "--fallback_yaml_path",
        help="path to fallback yaml file.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        help=("The directory to store the output yaml file"),
        required=True,
    )

    options = parser.parse_args(argv)
    # Validate explicitly rather than with `assert`: assertions are removed
    # under `python -O`, and a usage error is clearer than a bare
    # AssertionError. The option is required, so it can never be None here.
    if not os.path.exists(options.functions_yaml_path):
        parser.error(
            f"--functions-yaml-path does not exist: {options.functions_yaml_path}"
        )
    merge(options.functions_yaml_path, options.fallback_yaml_path, options.output_dir)
123
124
# Script entry point: when run directly (not imported), pass the command-line
# arguments, minus the program name, to main().
if __name__ == "__main__":
    main(sys.argv[1:])
127