# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
from typing import Any, Dict, List, Optional

import yaml

try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]


class BlankLineDumper(yaml.SafeDumper):
    """SafeDumper that writes an extra empty line between top-level entries."""

    def write_line_break(self, data=None):
        super().write_line_break(data)
        # When the emitter is one indentation level deep (i.e. between the
        # entries of the top-level list), insert a new line between entries.
        if len(self.indents) == 1:
            super().write_line_break()


def _get_canonical_opname(func: Dict[str, Any]) -> str:
    """Return the canonical name of an operator entry.

    "op" and "func" are the two keywords supported in the yaml files, e.g.:

    - op: add.Tensor  # mostly used for binding ATen ops to kernels
    - func: add.Tensor(Tensor self, Tensor other, Scalar alpha)  # mostly used for defining custom ops

    Both entries above map to the same canonical name. A name without a
    namespace (no "::") is placed in the "aten" namespace.

    Args:
        func: one entry parsed from a yaml file.

    Returns:
        str: canonical name of the operator, e.g. "aten::add.Tensor".
    """
    if "op" in func:
        opname = func["op"]
    else:
        # Strip the argument list from the schema: "add.Tensor(...)" -> "add.Tensor".
        opname = func["func"].split("(")[0]
    if "::" not in opname:
        opname = "aten::" + opname
    return opname


def _load_entries(path: str) -> List[Dict[str, Any]]:
    """Load the list of operator entries stored in the yaml file at ``path``.

    An empty file parses to ``None``; it is treated as an empty list so that
    the merge does not fail with a TypeError while iterating.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.load(f, Loader=Loader) or []


def merge(functions_yaml_path: str, fallback_yaml_path: Optional[str], output_dir: str):
    """Merge functions.yaml with an optional fallback yaml into ``merged.yaml``.

    The result contains every entry of functions.yaml, followed by the
    fallback entries whose canonical operator name does not already appear in
    functions.yaml. Entries are keyed by ``_get_canonical_opname``.

    Args:
        functions_yaml_path: path to the primary functions.yaml file.
        fallback_yaml_path: path to the fallback yaml file. It is skipped when
            it is None or when the file does not exist.
        output_dir: directory in which ``merged.yaml`` is written. It is
            created if it does not exist.
    """
    output_file = os.path.join(output_dir, "merged.yaml")

    # Insertion-ordered: functions.yaml entries first, then fallback-only ones.
    functions_dict: Dict[str, Dict[str, Any]] = {}
    for func in _load_entries(functions_yaml_path):
        functions_dict[_get_canonical_opname(func)] = func

    # A missing fallback file is ignored rather than treated as an error.
    if fallback_yaml_path is not None and os.path.exists(fallback_yaml_path):
        for func in _load_entries(fallback_yaml_path):
            # Entries from functions.yaml take precedence over fallback ones.
            functions_dict.setdefault(_get_canonical_opname(func), func)

    os.makedirs(output_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        yaml.dump(
            list(functions_dict.values()),
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )


def main(argv: List[Any]) -> None:
    """Merge functions.yaml and fallback yaml.

    The output yaml will be a union of all entries in functions.yaml and
    fallback yaml, with operator entries in functions.yaml overriding entries
    with the same op name in fallback yaml.

    E.g.,

    functions.yaml:
    - op: add.Tensor
      kernel: add_impl

    fallback yaml:
    - op: add.Tensor
      kernel: add_fallback
    - op: relu
      kernel: relu_fallback

    Merged:
    - op: add.Tensor
      kernel: add_impl
    - op: relu
      kernel: relu_fallback
    """
    parser = argparse.ArgumentParser(
        description="Merge functions.yaml, custom_ops.yaml with fallback yaml, for codegen to consume."
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use.",
        required=True,
    )
    parser.add_argument(
        "--fallback-yaml-path",
        "--fallback_yaml_path",
        help="path to fallback yaml file.",
        required=False,
    )
    parser.add_argument(
        # Both spellings are accepted, matching the other flags; dest is "output_dir".
        "--output-dir",
        "--output_dir",
        help="The directory to store the output yaml file",
        required=True,
    )
    options = parser.parse_args(argv)
    # Use parser.error instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(options.functions_yaml_path):
        parser.error(f"functions yaml file not found: {options.functions_yaml_path}")

    merge(options.functions_yaml_path, options.fallback_yaml_path, options.output_dir)


if __name__ == "__main__":
    main(sys.argv[1:])