# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
from collections import defaultdict  # noqa: F401  (kept for compatibility)
from typing import Any, Dict, List, Optional

import yaml

try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]


class BlankLineDumper(yaml.SafeDumper):
    """SafeDumper that emits an extra blank line between top-level entries."""

    def write_line_break(self, data=None):
        super().write_line_break(data)
        # Insert a new line between entries. The check on the emitter's indent
        # stack depth presumably identifies breaks at the top-level sequence;
        # NOTE(review): relies on PyYAML emitter internals (`self.indents`).
        if len(self.indents) == 1:
            super().write_line_break()


def merge(
    functions_yaml_path: str, fallback_yaml_path: Optional[str], output_dir: str
) -> None:
    """Write the union of two operator yaml files to ``<output_dir>/merged.yaml``.

    Entries from ``functions_yaml_path`` are written first, in file order, and
    take precedence. An entry from ``fallback_yaml_path`` is appended only when
    no entry with the same canonical operator name has already been seen.

    Args:
        functions_yaml_path: path to the primary functions.yaml file.
        fallback_yaml_path: optional path to the fallback yaml file. It is
            ignored when ``None`` or when the path does not exist.
        output_dir: directory in which ``merged.yaml`` is written; it is
            created if it does not exist yet.

    Raises:
        KeyError: if an entry has neither an ``op`` nor a ``func`` key.
    """
    output_file = os.path.join(output_dir, "merged.yaml")

    def get_canonical_opname(func: Dict[str, Any]) -> str:
        """Return the canonical (namespace-qualified) name of an operator entry.

        Two keys are supported for identifying an operator:
          - op: add.Tensor   # mostly used for binding ATen ops to kernels
          - func: add.Tensor(Tensor self, Tensor other, Scalar alpha)
                             # mostly used for defining custom ops
        For ``func`` the schema's argument list is dropped. Names without a
        ``::`` namespace are prefixed with ``aten::``.

        Args:
            func: one yaml entry (a mapping).

        Returns:
            The canonical name, e.g. ``aten::add.Tensor``.
        """
        if "op" in func:
            opname = func["op"]
        elif "func" in func:
            opname = func["func"].split("(")[0]
        else:
            raise KeyError(f"yaml entry has neither 'op' nor 'func' key: {func!r}")
        if "::" not in opname:
            opname = "aten::" + opname
        return opname

    def load_entries(path: str) -> List[Dict[str, Any]]:
        """Load the list of entries from a yaml file; an empty file yields []."""
        with open(path, encoding="utf-8") as f:
            entries = yaml.load(f, Loader=Loader)
        # yaml.load returns None for an empty document; treat it as no entries
        # instead of failing later with "'NoneType' object is not iterable".
        return entries or []

    # Plain dict: insertion order determines the output order, and entries
    # from functions.yaml are inserted first.
    functions_dict: Dict[str, Any] = {}
    for func in load_entries(functions_yaml_path):
        functions_dict[get_canonical_opname(func)] = func

    if fallback_yaml_path is not None and os.path.exists(fallback_yaml_path):
        for func in load_entries(fallback_yaml_path):
            # Only fill in operators that functions.yaml did not define.
            functions_dict.setdefault(get_canonical_opname(func), func)

    # An empty output_dir means "current directory", which needs no creation.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        yaml.dump(
            list(functions_dict.values()),
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )


def main(argv: List[Any]) -> None:
    """Merge functions.yaml and a fallback yaml into ``merged.yaml``.

    The output yaml is the union of all entries in functions.yaml and the
    fallback yaml. When both define the same operator (by canonical name), the
    functions.yaml entry overrides the fallback one.
    E.g.,
        functions.yaml:
            - op: add.Tensor
              kernel: add_impl

        fallback yaml:
            - op: add.Tensor
              kernel: add_fallback
            - op: relu
              kernel: relu_fallback

        Merged:
            - op: add.Tensor
              kernel: add_impl
            - op: relu
              kernel: relu_fallback

    Args:
        argv: command-line arguments, excluding the program name.
    """
    parser = argparse.ArgumentParser(
        description="Merge functions.yaml, custom_ops.yaml with fallback yaml, for codegen to consume."
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use.",
        required=True,
    )
    parser.add_argument(
        "--fallback-yaml-path",
        "--fallback_yaml_path",
        help="path to fallback yaml file.",
        required=False,
    )
    # Accept both spellings, like the other flags; argparse derives the dest
    # "output_dir" from the first long option.
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help="The directory to store the output yaml file",
        required=True,
    )

    options = parser.parse_args(argv)
    # Use parser.error rather than assert: asserts are stripped under -O.
    if not os.path.exists(options.functions_yaml_path):
        parser.error(f"functions yaml file not found: {options.functions_yaml_path}")
    merge(options.functions_yaml_path, options.fallback_yaml_path, options.output_dir)


if __name__ == "__main__":
    main(sys.argv[1:])