"""Generate CSV tables of the core ATen operator and prims operator schemas.

The tables are written under ``BUILD_DIR`` (relative to the current working
directory) as ``aten_ops.csv`` and ``prims_ops.csv``.  Each table has an
unquoted ``Operator,Schema`` header followed by one fully quoted row per
operator.
"""

import csv
import os
from collections import OrderedDict
from pathlib import Path

import torch
import torch._prims as prims
from torchgen.gen import parse_native_yaml


ROOT = Path(__file__).absolute().parent.parent.parent.parent
NATIVE_FUNCTION_YAML_PATH = ROOT / Path("aten/src/ATen/native/native_functions.yaml")
TAGS_YAML_PATH = ROOT / Path("aten/src/ATen/native/tags.yaml")

BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"
PRIMS_OPS_CSV_FILE = "prims_ops.csv"


def get_aten():
    """Return ``(operator, schema)`` pairs for every ATen op tagged ``core``.

    Operators are parsed from ``native_functions.yaml`` (together with
    ``tags.yaml``) and returned sorted by their overload name.  Each operator
    name is prefixed with ``aten.`` and every ``*`` in the schema is
    backslash-escaped (presumably so reST does not treat it as markup).
    """
    parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)

    aten_ops = OrderedDict()
    for function in parsed_yaml.native_functions:
        if "core" in function.tags:
            aten_ops[str(function.func.name)] = function

    op_schema_pairs = []
    # Keys are unique dict keys, so sorting the items orders by name only.
    for key, op in sorted(aten_ops.items()):
        op_name = f"aten.{key}"
        schema = str(op.func).replace("*", r"\*")
        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


def get_prims():
    """Return ``(operator, schema)`` pairs for the prims listed in ``prims.__all__``.

    Only names that resolve to a ``torch._ops.OpOverload`` are included, in
    ``__all__`` order.  A trailing ``.default`` overload suffix is dropped from
    the operator name, and every ``*`` in the schema is backslash-escaped.
    """
    default_suffix = ".default"
    op_schema_pairs = []
    for attr_name in prims.__all__:
        op_overload = getattr(prims, attr_name, None)

        if not isinstance(op_overload, torch._ops.OpOverload):
            continue

        op_name = str(op_overload)
        # Strip only a trailing ".default"; str.replace would also remove the
        # substring if it ever appeared elsewhere in the qualified name.
        if op_name.endswith(default_suffix):
            op_name = op_name[: -len(default_suffix)]
        schema = op_overload.overloadpacket.schema.replace("*", r"\*")

        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


def write_ops_csv(path, op_schema_pairs):
    """Write ``(operator, schema)`` pairs to ``path`` as a two-column CSV table.

    The ``Operator,Schema`` header is written unquoted.  Every data field is
    quoted and the operator name is wrapped in double backticks.  Embedded
    double quotes are doubled by the ``csv`` module, so schemas containing
    string defaults such as ``padding="valid"`` cannot break the table.
    """
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write("Operator,Schema\n")
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        for name, schema in op_schema_pairs:
            writer.writerow((f"``{name}``", schema))


def main():
    """Generate the ATen and prims operator CSV tables under ``BUILD_DIR``."""
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()

    os.makedirs(BUILD_DIR, exist_ok=True)

    write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)


if __name__ == "__main__":
    main()