# mypy: ignore-errors

import argparse

import torchgen.model as model
from torchgen.gen import FileManager, parse_native_yaml


def num_leading_spaces(line: str) -> int:
    """Return the number of leading whitespace characters in ``line``."""
    return len(line) - len(line.lstrip())


def deindent(code: str) -> str:
    """Strip the indentation common to every non-blank line of ``code``.

    Whitespace-only lines are ignored when computing the common indent;
    otherwise a single blank line would force it to zero and nothing would
    be stripped. Blank lines shorter than the common indent become empty.
    """
    lines = code.split("\n")
    indents = [num_leading_spaces(line) for line in lines if line.strip()]
    common = min(indents, default=0)
    return "\n".join(line[common:] for line in lines)


def _is_plain_tensor(arg) -> bool:
    """Return True if ``arg``'s schema type is exactly ``BaseType(Tensor)``."""
    return (
        isinstance(arg.type, model.BaseType)
        and arg.type.name == model.BaseTy.Tensor
    )


def _codegen_extern_call(func):
    """Build the C++ wrapper for one native function.

    Returns a ``(name, declaration, registration)`` tuple of strings, or
    ``None`` when ``func`` is not supported by this generator (see the
    filters below).
    """
    schema = func.func
    arguments = schema.arguments

    # Only supports extern calls for functions with out variants.
    if not schema.is_out_fn():
        return None
    # Doesn't currently support functions with more than one out parameter.
    if len(arguments.out) > 1:
        return None
    # Doesn't currently support kwarg-only arguments.
    if (
        len(arguments.pre_tensor_options_kwarg_only) > 0
        or len(arguments.post_tensor_options_kwarg_only) > 0
    ):
        return None

    self_arg = [arguments.self_arg.argument] if arguments.self_arg is not None else []
    positional = [
        *arguments.pre_self_positional,
        *self_arg,
        *arguments.post_self_positional,
    ]
    # Every input is read back as an at::Tensor from the buffer array below,
    # so each positional input must be a plain Tensor.
    if not all(_is_plain_tensor(arg) for arg in positional):
        return None

    name = schema.name.name.base
    # tensors[0] is bound to the out tensor `r`, so inputs start at index 1.
    tensor_decls = [
        f"const at::Tensor& {arg.name} = tensors[{idx}];"
        for idx, arg in enumerate(positional, start=1)
    ]
    # Joined outside the f-string: backslashes are not allowed inside
    # f-string expressions before Python 3.12.
    decls = "\n  ".join(tensor_decls)
    call_args = ", ".join(["r"] + [arg.name for arg in positional])

    # NOTE(review): the generated wrapper catches and discards every C++
    # exception thrown by at::<name>_out, so failures are silent at runtime.
    declaration = f"""\
void nnc_aten_{name}(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {{
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
  at::Tensor& r = tensors[0];
  {decls}
  try {{
    at::{name}_out({call_args});
  }} catch (...) {{
  }}
}}"""
    registration = f"""\
const static RegisterNNCExternalFunction nnc_{name}(
    "nnc_aten_{name}",
    nnc_aten_{name});"""
    return name, declaration, registration


def gen_external(native_functions_path, tags_path, external_path):
    """Generate ``external_functions_codegen.cpp`` with NNC extern wrappers.

    Args:
        native_functions_path: path to ``native_functions.yaml``.
        tags_path: path to ``tags.yaml``.
        external_path: path of the template passed to
            ``FileManager.write_with_template``.

    The output is written via a ``FileManager`` whose ``install_dir`` is
    ``"."`` (i.e. relative to the current working directory).
    """
    parsed = parse_native_yaml(native_functions_path, tags_path)
    # Recent torchgen returns a ParsedYaml named tuple; fall back to the
    # object itself in case a plain sequence of NativeFunctions is returned.
    native_functions = getattr(parsed, "native_functions", parsed)

    func_decls = []
    func_registrations = []
    emitted_names = set()
    for func in native_functions:
        generated = _codegen_extern_call(func)
        if generated is None:
            continue
        name, declaration, registration = generated
        # Overloads can share a base name; defining nnc_aten_<name> twice in
        # one translation unit would not compile, so keep the first only.
        if name in emitted_names:
            continue
        emitted_names.add(name)
        func_decls.append(declaration)
        func_registrations.append(registration)

    fm = FileManager(install_dir=".", template_dir=".", dry_run=False)
    fm.write_with_template(
        "external_functions_codegen.cpp",
        external_path,
        lambda: {
            "external_registrations": func_registrations,
            "external_functions": func_decls,
        },
    )


def main() -> None:
    """Parse command-line paths and run ``gen_external``."""
    parser = argparse.ArgumentParser(
        description="Generate external_functions_codegen.cpp for NNC external calls"
    )
    parser.add_argument(
        "--native-functions",
        "--native_functions",
        help="path to native_functions.yaml",
        default="../../../../aten/src/ATen/native/native_functions.yaml",
    )
    parser.add_argument(
        "--tags",
        help="path to tags.yaml",
        default="../../../../aten/src/ATen/native/tags.yaml",
    )
    parser.add_argument(
        "--template-path",
        "--template_path",
        help="path to external_functions_codegen_template.cpp",
        default="../../../../tools/jit/templates/external_functions_codegen_template.cpp",
    )
    args = parser.parse_args()
    gen_external(args.native_functions, args.tags, args.template_path)


if __name__ == "__main__":
    main()