xref: /aosp_15_r20/external/pytorch/docs/source/scripts/build_opsets.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2from collections import OrderedDict
3from pathlib import Path
4
5import torch
6import torch._prims as prims
7from torchgen.gen import parse_native_yaml
8
9
# Repository root. The script lives in <root>/docs/source/scripts/, so walk up
# four directory levels from this file's absolute location.
ROOT = Path(__file__).absolute().parent.parent.parent.parent

# Both YAML inputs sit in the same ATen "native" directory.
_ATEN_NATIVE_DIR = ROOT / "aten" / "src" / "ATen" / "native"
NATIVE_FUNCTION_YAML_PATH = _ATEN_NATIVE_DIR / "native_functions.yaml"
TAGS_YAML_PATH = _ATEN_NATIVE_DIR / "tags.yaml"

# Output directory (a relative path, so resolved against the current working
# directory) and the file names of the generated operator tables.
BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"
PRIMS_OPS_CSV_FILE = "prims_ops.csv"
17
18
def get_aten():
    """Collect the ATen native functions tagged ``core``.

    Parses ``native_functions.yaml`` (together with ``tags.yaml``) and keeps
    every native function whose tags contain ``"core"``.

    Returns:
        A list of ``(operator_name, schema)`` string pairs sorted by function
        name. ``operator_name`` is the function name prefixed with ``aten.``.
        ``schema`` is the function schema with every ``*`` backslash-escaped
        so the docs markup does not interpret it.
    """
    parsed = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)

    # Keyed by name; on a repeated name the later entry wins, as before.
    core_by_name = {
        str(fn.func.name): fn
        for fn in parsed.native_functions
        if "core" in fn.tags
    }

    return [
        (f"aten.{name}", str(core_by_name[name].func).replace("*", r"\*"))
        for name in sorted(core_by_name)
    ]
37
38
def get_prims():
    """Collect the operators exported by ``torch._prims``.

    Walks ``prims.__all__`` in order and keeps only the entries that are
    ``torch._ops.OpOverload`` instances. Any other exported names are skipped.

    Returns:
        A list of ``(operator_name, schema)`` string pairs. ``operator_name``
        is the overload's string form with ``".default"`` removed. ``schema``
        is the ``schema`` attribute of the overload's packet with every ``*``
        backslash-escaped.
    """
    pairs = []
    for exported in prims.__all__:
        overload = getattr(prims, exported, None)
        if isinstance(overload, torch._ops.OpOverload):
            packet = overload.overloadpacket
            display_name = str(overload).replace(".default", "")
            escaped_schema = packet.schema.replace("*", r"\*")
            pairs.append((display_name, escaped_schema))
    return pairs
55
56
def _write_ops_csv(path, op_schema_pairs):
    """Write ``(operator, schema)`` pairs to *path* as a two-column CSV table.

    The header row is ``Operator,Schema``. Each data row has two fully quoted
    fields: the operator name wrapped in double backticks, and the schema.
    """
    import csv

    with open(path, "w", encoding="utf-8", newline="") as f:
        # Keep the header unquoted, byte-for-byte as previously emitted.
        f.write("Operator,Schema\n")
        # QUOTE_ALL quotes every field like the old hand-formatted rows. It
        # also doubles any embedded '"' (e.g. a string default value in a
        # schema). The previous f-string formatting left such quotes raw,
        # which terminates the field early and corrupts the table.
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        writer.writerows(
            (f"``{name}``", schema) for name, schema in op_schema_pairs
        )


def main():
    """Generate the core ATen and prims operator tables.

    Creates ``BUILD_DIR`` if it does not exist. ``BUILD_DIR`` is a relative
    path, so it is resolved against the current working directory. Then
    writes ``ATEN_OPS_CSV_FILE`` (from ``get_aten``) and
    ``PRIMS_OPS_CSV_FILE`` (from ``get_prims``) inside it.
    """
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()

    os.makedirs(BUILD_DIR, exist_ok=True)

    _write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    _write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)


if __name__ == "__main__":
    main()
76