xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/codegen_external.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import argparse
4
5import torchgen.model as model
6from torchgen.gen import FileManager, parse_native_yaml
7
8
def num_leading_spaces(line: str) -> int:
    """Count the whitespace characters at the start of *line*.

    Every character accepted by ``str.isspace`` (spaces, tabs, ...) counts as
    one, matching what ``str.lstrip()`` would remove. A line made up only of
    whitespace yields its full length; an empty line yields 0.
    """
    first_non_space = next(
        (pos for pos, char in enumerate(line) if not char.isspace()),
        len(line),
    )
    return first_non_space
11
12
def deindent(code: str) -> str:
    """Remove the indentation common to all non-blank lines of *code*.

    The common indent is the smallest ``num_leading_spaces`` over lines that
    contain non-whitespace text; that many leading characters are then cut
    from every line. Whitespace-only lines are trimmed by the same amount
    (ending up empty if they were shorter). If *code* has no non-blank line,
    it is returned unchanged.
    """
    lines = code.split("\n")
    # Skip blank / whitespace-only lines when measuring the indent (as
    # textwrap.dedent does): an empty line has 0 leading whitespace and would
    # otherwise force the common indent to 0, turning this into a no-op.
    indent = min(
        (num_leading_spaces(line) for line in lines if line.strip()),
        default=0,
    )
    return "\n".join(line[indent:] for line in lines)
18
19
def _extern_tensor_args(schema):
    """Return *schema*'s positional arguments if an NNC wrapper can be generated.

    Returns None when the function is not supported: it must be an out=
    variant with exactly one out parameter, no keyword-only arguments besides
    out=, and only plain ``Tensor`` positional arguments.
    """
    arguments = schema.arguments
    # Only out= variants are wrapped: the generated wrapper passes the first
    # buffer as the result tensor to at::<name>_out.
    if not schema.is_out_fn():
        return None
    # More than one out parameter is not supported.
    if len(arguments.out) > 1:
        return None
    # Keyword-only arguments (other than out=) are not supported.
    if (
        len(arguments.pre_tensor_options_kwarg_only) > 0
        or len(arguments.post_tensor_options_kwarg_only) > 0
    ):
        return None
    self_arg = [arguments.self_arg.argument] if arguments.self_arg is not None else []
    positional = [
        *arguments.pre_self_positional,
        *self_arg,
        *arguments.post_self_positional,
    ]
    # Every positional argument must be a plain (non-optional, non-list) Tensor.
    if not all(
        isinstance(arg.type, model.BaseType) and arg.type.name == model.BaseTy.Tensor
        for arg in positional
    ):
        return None
    return positional


def _render_extern_function(name, tensor_args):
    """Render the C++ ``nnc_aten_<name>`` wrapper for the given Tensor args.

    ``tensors[0]`` is bound to the result ``r``; the inputs are bound to
    ``tensors[1..]`` in argument order and forwarded to ``at::<name>_out``.
    """
    # Join with the template's two-space indent so every declaration, not
    # just the first, is indented in the generated C++.
    tensor_decls = "\n  ".join(
        f"const at::Tensor& {arg.name} = tensors[{idx}];"
        for idx, arg in enumerate(tensor_args, start=1)
    )
    call_args = ", ".join(["r", *(arg.name for arg in tensor_args)])
    return f"""\
void nnc_aten_{name}(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {{
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
  at::Tensor& r = tensors[0];
  {tensor_decls}
  try {{
    at::{name}_out({call_args});
  }} catch (...) {{
  }}
}}"""


def _render_extern_registration(name):
    """Render the static RegisterNNCExternalFunction entry for ``nnc_aten_<name>``."""
    return f"""\
const static RegisterNNCExternalFunction nnc_{name}(
    "nnc_aten_{name}",
    nnc_aten_{name});"""


def gen_external(native_functions_path, tags_path, external_path):
    """Generate external_functions_codegen.cpp with NNC wrappers for ATen ops.

    A wrapper and a registration are emitted for every native function that
    ``_extern_tensor_args`` accepts.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        external_path: template file rendered into the output; its
            ``external_functions`` and ``external_registrations`` slots receive
            the generated code.

    The output is written to ``external_functions_codegen.cpp`` under the
    current working directory (FileManager install_dir=".").
    """
    # parse_native_yaml returns a ParsedYaml named tuple; iterating it
    # directly would yield its fields, not the NativeFunction entries.
    native_functions = parse_native_yaml(
        native_functions_path, tags_path
    ).native_functions
    func_decls = []
    func_registrations = []
    for func in native_functions:
        tensor_args = _extern_tensor_args(func.func)
        if tensor_args is None:
            continue
        # NOTE(review): only the base name is used, so two supported overloads
        # sharing a base name would emit duplicate nnc_aten_<name> symbols;
        # no de-duplication is done here.
        name = func.func.name.name.base
        func_decls.append(_render_extern_function(name, tensor_args))
        func_registrations.append(_render_extern_registration(name))
    fm = FileManager(install_dir=".", template_dir=".", dry_run=False)
    fm.write_with_template(
        "external_functions_codegen.cpp",
        external_path,
        lambda: {
            "external_registrations": func_registrations,
            "external_functions": func_decls,
        },
    )
99
100
def main() -> None:
    """Command-line entry point: parse arguments and call ``gen_external``.

    The default paths are relative to the current working directory; they
    assume the script is run from the directory containing this file (four
    levels below the repository root).
    """
    # The description previously read "Generate annotated_fn_args script",
    # a copy-paste leftover from a different generator.
    parser = argparse.ArgumentParser(
        description="Generate external_functions_codegen.cpp (NNC external function wrappers)"
    )
    parser.add_argument(
        "--native-functions",
        "--native_functions",
        help="path to native_functions.yaml",
        default="../../../../aten/src/ATen/native/native_functions.yaml",
    )
    parser.add_argument(
        "--tags",
        help="path to tags.yaml",
        default="../../../../aten/src/ATen/native/tags.yaml",
    )
    parser.add_argument(
        "--template-path",
        "--template_path",
        help="path to external_functions_codegen_template.cpp",
        default="../../../../tools/jit/templates/external_functions_codegen_template.cpp",
    )
    args = parser.parse_args()
    gen_external(args.native_functions, args.tags, args.template_path)


if __name__ == "__main__":
    main()
126