# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.

from __future__ import annotations

import argparse
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING

import yaml

from torchgen.api import cpp, unboxing
from torchgen.api.translate import translate
from torchgen.api.types import CppSignatureGroup
from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
    """Render one native function as an ``at::unboxing`` C++ declaration or definition.

    Instances are called once per ``NativeFunction``. Each call returns the C++
    source text for that function, or ``""`` when the selector does not treat
    the operator as a root operator.
    """

    # Which half of the API to emit: the header prototype (DECLARATION) or the
    # .cpp body (DEFINITION).
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    # Operators for which `is_root_operator("aten::<name>")` is false produce "".
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the C++ text for ``f`` according to ``self.target``.

        DECLARATION yields a ``TORCH_API void <name>(Stack & stack);`` prototype.
        DEFINITION yields the matching function body, which:

        - runs the argument-conversion code from ``convert_arguments``;
        - emits ``drop(stack, N)``, where N is the number of bindings;
        - calls the most faithful C++ signature;
        - emits ``pack(...)`` on the result when ``f`` has returns.
        """
        if not self.selector.is_root_operator(f"aten::{f.func.name}"):
            return ""

        if self.target is Target.DECLARATION:
            # Note [The ATen Codegen Unboxing API]
            # Similar to the ATen Operators API, ATen Codegen Unboxing API lives in the at::unboxing namespace, and
            # will be used by codegen unboxing wrappers (CodegenUnboxingWrappers.cpp).
            # The Wrappers will be registered into torch::jit::OperatorRegistry using RegisterOperators API.
            #
            # Important characteristics about the Codegen Unboxing API:
            # (1) It follows the OperatorRegistry API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Under the hood it calls C++ API.
            return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""
        else:
            sig_group = CppSignatureGroup.from_native_function(
                f, method=(Variant.method in f.variants)
            )
            sig = sig_group.most_faithful_signature()
            # parse arguments into C++ code: one binding plus conversion
            # statements per argument
            binding_list, code_list = convert_arguments(f)

            # separators for the generated conversion statements and call args
            code_connector = "\n\t"
            arg_connector = ", "
            # Method variants are called on `self_base`. This assumes the code
            # produced by convert_arguments declares that local; verify there.
            prefix = "self_base." if sig.method else "at::"
            translated_args = translate(
                binding_list, sig.arguments(), method=sig.method
            )
            args_str = arg_connector.join(e.expr for e in translated_args)
            if len(f.func.returns) == 0:
                # nothing to push back onto the stack
                ret_str = ""
                push_str = ""
            else:
                ret_str = "auto result_ = "
                push_str = """
    pack(stack, std::move(result_));
    """
            return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack) {{
    {code_connector.join(code_list)}

    drop(stack, {len(binding_list)});

    {ret_str}{prefix}{sig.name()}({args_str});
    {push_str}
}}
"""


# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    """Render one ``OperatorGenerator(...)`` registration entry per native function.

    Each call returns the C++ entry text, or ``""`` when the selector does not
    treat the operator as a root operator.
    """

    # Operators for which `is_root_operator("aten::<name>")` is false produce "".
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the ``OperatorGenerator`` entry for ``f``.

        The entry contains:

        - the operator name and overload name;
        - a ``c10::Argument`` for every argument, carrying its default when one exists;
        - a ``c10::Argument`` for every return;
        - a lambda that calls ``at::unboxing::<name>(stack)``.
        """
        if not self.selector.is_root_operator(f"aten::{f.func.name}"):
            return ""
        # Always use the plain function (method=False) signature here, even
        # when the operator also has a method variant.
        sig_group = CppSignatureGroup.from_native_function(f, method=False)
        sig = sig_group.most_faithful_signature()

        # escape double quote in schema, get rid of extra double quotes
        schema = cpp_string(str(sig.func))[1:-1]

        # arguments
        connector = ",\n\t\t"
        args_code = []
        for arg in sig.arguments():
            # Using method=False faithful C++ API, so we should not see SelfArgument/TensorOptionsArgument
            assert isinstance(arg.argument, Argument)
            if not arg.argument.default:
                arg_cpp = "c10::IValue(::std::nullopt)"
            else:
                # The unboxing code uses the faithful C++ API to avoid the overhead
                # from wrapping/unwrapping TensorOptions.
                # However, we would like to include default args for schema parsing.
                # Default args only show up in the non-faithful C++ API, so the
                # schema default is rendered into a C++ expression explicitly.
                arg_default = cpp.default_expr(
                    arg.argument.default, arg.argument.type, symint=False
                )
                if arg_default.startswith("{"):
                    # brace-init list default; wrapped as an IntArrayRef
                    # (presumably only int-list defaults render this way)
                    arg_cpp = f"c10::IntArrayRef({arg_default})"
                else:
                    arg_cpp = f"c10::IValue({arg_default})"
            args_code.append(
                f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})"""
            )

        # unnamed returns get an empty-string name
        returns_code = [
            f"""c10::Argument("{ret.name if ret.name else ""}")"""
            for ret in f.func.returns
        ]
        return f"""
// aten::{schema}
OperatorGenerator(
    "aten::{f.func.name.name}",
    "{f.func.name.overload_name}",
    {{
        {connector.join(args_code)}
    }},
    {{
        {connector.join(returns_code)}
    }},
    [](Stack & stack) {{
        RECORD_FUNCTION("{sig.name()}", std::vector<c10::IValue>());
        at::unboxing::{unboxing.name(f)}(stack);
    }},
    aliasAnalysisFromSchema()
),
"""


def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Write UnboxingFunctions.cpp, UnboxingFunctions.h and RegisterCodegenUnboxedKernels.cpp.

    The two .cpp files go through ``cpu_fm.write_sharded`` with ``root_name`` as
    the key function. They use a single shard when fewer than
    ``sharding_threshold`` operators are selected. Otherwise they use 5 and 10
    shards respectively.
    """

    def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
        return fn.root_name

    selected_op_num: int = len(selector.operators)
    # a best practice threshold of operators to enable sharding
    sharding_threshold: int = 100
    small_build = selected_op_num < sharding_threshold

    # The generators are frozen dataclasses that only hold target/selector,
    # so build each one once instead of once per native function.
    compute_definitions = ComputeUnboxingFunctions(Target.DEFINITION, selector)
    compute_declarations = ComputeUnboxingFunctions(Target.DECLARATION, selector)
    compute_kernels = ComputeCodegenUnboxedKernels(selector)

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"definitions": [compute_definitions(fn)]},
        num_shards=1 if small_build else 5,
        sharded_keys={"definitions"},
    )
    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {
            "declarations": list(mapMaybe(compute_declarations, native_functions)),
        },
    )
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"unboxed_ops": [compute_kernels(fn)]},
        num_shards=1 if small_build else 10,
        sharded_keys={"unboxed_ops"},
    )


def main(args: list[str]) -> None:
    """Parse ``args`` (argv without the program name) and generate the unboxing sources."""
    parser = argparse.ArgumentParser(description="Generate unboxing source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--op-registration-allowlist",
        "--op_registration_allowlist",
        nargs="*",
        help="filter op registrations by the allowlist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--TEST-ONLY-op-registration-allowlist-yaml-path",
        "--TEST_ONLY_op_registration_allowlist_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "which contains a list of operators. It is to serve testing purpose and "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )

    options = parser.parse_args(args)
    # The command-line allowlist takes precedence over the test-only YAML file.
    if options.op_registration_allowlist:
        op_registration_allowlist = options.op_registration_allowlist
    elif options.TEST_ONLY_op_registration_allowlist_yaml_path:
        with open(
            options.TEST_ONLY_op_registration_allowlist_yaml_path, encoding="utf-8"
        ) as f:
            op_registration_allowlist = yaml.safe_load(f)
    else:
        op_registration_allowlist = None

    selector = get_custom_build_selector(
        op_registration_allowlist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
    # Only the native functions are needed; backend indices are unused here.
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions

    cpu_fm = make_file_manager(options=options)
    gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        # write_outputs receives the depfile's stem and its absolute path.
        cpu_fm.write_outputs(depfile_path.stem, str(depfile_path))


if __name__ == "__main__":
    main(sys.argv[1:])