from __future__ import annotations

import argparse
import functools
import json
import os
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar

import yaml

import torchgen.api.dispatcher as dispatcher
import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
    Binding,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    NamedCType,
    NativeSignature,
    SpecialArgName,
)
from torchgen.context import (
    method_with_native_function,
    native_function_manager,
    with_native_function,
    with_native_function_and_indices,
)
from torchgen.gen_aoti_c_shim import (
    gen_aoti_c_shim,
    gen_static_dispatch_backend_call_signature,
    get_fallback_op_name,
    get_header_for_aoti,
)
from torchgen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
    GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
from torchgen.model import (
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseOperatorName,
    DEFAULT_KERNEL_NAMESPACE,
    DispatchKey,
    FRAGMENT_NAMESPACES,
    FunctionSchema,
    is_cuda_dispatch_key,
    is_generic_dispatch_key,
    is_ufunc_dispatch_key,
    is_xpu_dispatch_key,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OperatorName,
    OptionalType,
    SchemaKind,
    SelfArgument,
    STRUCTURED_DISPATCH_KEYS,
    TensorOptionsArguments,
    Type,
    Variant,
    ViewSchemaKind,
)
from torchgen.native_function_generation import (
    add_generated_native_functions,
    gen_composite_functional_kernel,
    gen_composite_out_kernel,
    pre_group_native_functions,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    assert_never,
    concatMap,
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
    Target,
)
from torchgen.yaml_utils import YamlDumper, YamlLoader


T = TypeVar("T")

# Welcome to the ATen code generator v2!  The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file.  This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking.  Typecheck it with
#   `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml.  The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with.  There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API.  See each
#     of these respective files for more information

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping["__line__"] = node.start_mark.line + 1
        return mapping


# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])


_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}


def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    assert isinstance(es, list)
    rs: list[NativeFunction] = []
    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for e in es:
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        funcs = e.get("func")
        assert funcs is not None, f"missing 'func' in {e}"
        with context(lambda: f"in {loc}:\n  {funcs}"):
            func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
            rs.append(func)
            BackendIndex.grow_index(bs, m)
    error_check_native_functions(rs)
    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )
    )
    if not skip_native_fns_gen:
        add_generated_native_functions(rs, bs)
    for k, v in bs.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k),
            index=v,
        )
    return ParsedYaml(rs, indices)


def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    assert isinstance(es, list)
    rs: set[str] = set()
    for e in es:
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        with context(lambda: f"in {loc}:\n  {tags}"):
            e_i = e.copy()
            name = e_i.pop("tag")
            desc = e_i.pop("desc", "")
            # ensure that each tag has a non-empty description
            assert desc != ""
            rs.add(name)
    return rs


@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)
            _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]


def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        valid_tags = parse_tags_yaml(tags_yaml_path)

        # if a loaded yaml is provided, use that instead of reading from path
        if loaded_yaml is None:
            with open(path) as f:
                es = yaml.load(f, Loader=LineLoader)
        else:
            es = loaded_yaml

        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
            es,
            valid_tags,
            ignore_keys,
            path=path,
            skip_native_fns_gen=skip_native_fns_gen,
        )

    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]


# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
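        # A hedged illustration of the check below (the op names are just
        # examples, not a claim about how they are tagged today): an entry
        # like `transpose_(Tensor(a!) self, ...)` carrying the `inplace_view`
        # tag is expected to end in a trailing underscore and to have a
        # matching out-of-place view op `transpose(Tensor(a) self, ...)`
        # declared elsewhere in native_functions.yaml.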
        if (
            "inplace_view" in f.tags
            and str(f.func.name) != "resize_"
            and str(f.func.name) != "resize_as_"
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            assert len(base_func_map[out_of_place_base_name]) > 0, (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
            )


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal"""
    s = s.replace("\\", "\\\\")
    s = s.replace('"', '\\"')
    s = s.replace("\a", "\\a")
    s = s.replace("\b", "\\b")
    s = s.replace("\f", "\\f")
    s = s.replace("\n", "\\n")
    s = s.replace("\v", "\\v")
    s = s.replace("\t", "\\t")
    return f'"{s}"'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated.  This pattern makes it convenient to use map, concatMap
# and similar functional combinators.


def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    if len(backends) == 0:
        return []
    else:
        return [backend.dispatch_key for backend in backends] + [
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]


def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    elif f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None


def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    if backend_index is None or f.manual_kernel_registration:
        return None

    output = []
    for index in backend_index:
        dispatch_key = get_static_dispatch_backend(f, index)
        if dispatch_key is not None:
            output.append(
                f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
            )
    return "\n".join(output)


def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    return [
        f"#include <ATen/{dispatch_key}Functions.h>"
        for dispatch_key in static_dispatch_keys(backends)
    ]


# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    cpp_sig: CppSignature,
) -> str:
    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
        output_bindings: list[Binding] = []
        for binding in input_bindings:
            if binding.name == "memory_format":
                spl_mem_format_binding = Binding(
                    nctype=NamedCType(
                        SpecialArgName.possibly_redundant_memory_format,
                        binding.nctype.type,
                    ),
                    name=binding.name,
                    default=binding.default,
                    argument=binding.argument,
                )
                output_bindings.append(spl_mem_format_binding)
            else:
                output_bindings.append(binding)
        return output_bindings

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
    # get memory_format bindings of dispatcher signature to have the same NCType as well
    for arg in goal_bindings:
        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
            src_bindings = add_spl_memory_format_binding(src_bindings)
            break
    exprs = translate(src_bindings, goal_bindings)
    return ", ".join(a.expr for a in exprs)


def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    backend_metadata = backend_index.get_kernel(f)
    kernel_ns = (
        backend_metadata.cpp_namespace
        if backend_metadata and backend_metadata.cpp_namespace
        else DEFAULT_KERNEL_NAMESPACE
    )
    ns = kernel_ns.replace("::native", "")
    return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"


def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
    if f.has_composite_explicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
    else:
        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key) for index in backend_indices])} ");"""


def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the corresponding backend and dispatch to it. If more than one
    backend exists, fall back to determining the dispatch key from the inputs.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or isinstance(a.argument, Argument)
        and a.argument.type.is_tensor_like()
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """
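
# To give a flavor of the output, the multi-backend case produced by
# static_dispatch() comes out roughly like the following (the operator, the
# arguments, and the exact set of case labels here are illustrative, not taken
# from a real native_functions.yaml entry):
#
#   DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(self, other);
#   DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);
#   switch (_dk) {
#     case DispatchKey::CPU:
#       return at::cpu::add(self, other);
#     case DispatchKey::CUDA:
#       return at::cuda::add(self, other);
#     default:
#       TORCH_CHECK(false, "Static dispatch does not support add for ...");
#   }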


# Generates RegisterSchema.cpp.  Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder
    known_tags: dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not self.selector.is_native_function_selected(f):
            return None
        tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
        if tags == "{}":
            return f"m.def({cpp_string(str(f.func))}, {{}});\n"
        maybe_tags = ""
        if tags not in self.known_tags:
            idx = len(self.known_tags)
            self.known_tags[tags] = idx
            maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
        return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"


# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch APIs are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
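            # As a hedged, concrete illustration (the f-string below is the
            # authoritative template), the declaration generated for
            # aten::mul.Tensor comes out roughly as:
            #
            #   struct TORCH_API mul_Tensor {
            #     using schema = at::Tensor (const at::Tensor &, const at::Tensor &);
            #     using ptr_schema = schema*;
            #     STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::mul")
            #     STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
            #     static at::Tensor call(const at::Tensor & self, const at::Tensor & other);
            #     static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other);
            #   };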
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name="call", is_redispatching_fn=False)};
  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})

// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    method_base = "redispatch"
                else:
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    method_base = "call"

                dispatcher_call = method_base
                method_name = f"{name}::{method_base}"

                fn_body = f"""
    static auto op = create_{name}_typed_handle();
    return op.{dispatcher_call}({dispatcher_exprs_str});"""

                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    fn_body = static_dispatch(
                        sig, f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    {fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)


# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        has_symint = f.func.has_symint()

        result = ""
        for sig in sig_group.signatures():
            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join([e.expr for e in exprs])

            if sig.symint:
                intlike_t = "c10::SymInt"
            else:
                intlike_t = "int64_t"

            if Variant.function in f.variants:
                result += f"""
// aten::{f.func}
inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}"""

            # The template function can be used from template situations
            # where you want to switch between the symint and non-symint versions
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if has_symint:
                result += f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  }}
}}
"""
        return result


# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            result = ""
            for sig in sig_group.signatures():
                result += f"{sig.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        result = ""

        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ", ".join([e.expr for e in exprs])

            result += f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""

        return result


# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        result = ""
        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])

            result += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""

        return result


# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another, we should reevaluate if this is actually needed.
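# For reference, each emitted entry is a single brace-initializer line such as
# `{"aten::add", "Tensor"},` (the operator shown here is purely illustrative);
# the surrounding template collects these entries into the runtime list.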
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'


# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
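                # Putting those pieces together, a generated setter looks
                # roughly like this (the element names are purely illustrative):
                #
                #   precompute_out<true, DIM> set_size(int64_t value) {
                #     static_assert(SIZE == false, "size already set");
                #     precompute_out<true, DIM> ret;
                #     ret.size = value;
                #     ret.dim = this->dim;
                #     return ret;
                #   }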
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")

                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                    """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
                }};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""


def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    name = str(f.func.name.name)
    if name.endswith("_like") or name.startswith("new_"):
        return False
    if f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)


# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: NativeSignature | DispatcherSignature
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def format_yaml(data: object) -> str:
    # Ignore alias in Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]

    # Support serializing OrderedDict
    def dict_representer(dumper: Any, data: Any) -> Any:
        return dumper.represent_dict(data.items())

    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]


# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings. This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    if s == "true":
        return True
    elif s == "false":
        return False

    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s


# What is a dynamic type?  Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH).  These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
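# A few illustrative mappings (not exhaustive): "Tensor" and "Tensor?" both
# report "at::Tensor", "int[3]" reports "at::IntArrayRef", and "Scalar"
# reports the usual C++ argument type ("const at::Scalar &").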
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    if isinstance(t, OptionalType):
        return dynamic_type(t.elem)
    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == "Tensor":
        return "at::Tensor"
    # This is a legacy concept, so never report SymInt
    return cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    ).cpp_type()


def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    method_of = ["Type"]
    if Variant.method in variants:
        method_of.append("Tensor")
    if Variant.function in variants:
        method_of.append("namespace")
    return method_of


def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #     - field_name: solution
    #       name: X
    #     - field_name: QR
    #       name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            "dynamic_type": dynamic_type(r.type),
            "name": name,
            # legacy, report ints
            "type": cpp.return_type(r, symint=False).cpp_type(),
        }

        if r.name:
            # See Note [name and field_name]
            ret["field_name"] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name

        returns.append(ret)

    return returns, name_to_field_name


# arguments in yaml roughly correspond to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    if isinstance(cpp_a.argument, TensorOptionsArguments):
        arg: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            arg["default"] = cpp_a.default
        return arg
    elif isinstance(cpp_a.argument, SelfArgument):
        raise AssertionError
    elif isinstance(cpp_a.argument, Argument):
        return compute_argument_yaml(
            cpp_a.argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )


def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    arg: dict[str, object] = {
        "annotation": str(a.annotation) if a.annotation else None,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        # legacy, report ints
        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
    }
    if a.default is not None:
        arg["default"] = pythonify_default(
            cpp.default_expr(a.default, a.type, symint=False)
        )
    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True
    if a.name in out_arg_set:
        arg["output"] = True
        arg["allocate"] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg["field_name"] = name_to_field_name[a.name]
    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    l = a.type.is_list_like()
    if l is not None and l.size is not None and str(l.elem) != "bool":
        arg["size"] = l.size
    return arg


@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in f.func.arguments.out}

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            cpp_a,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for cpp_a in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    schema_order_arguments = [
        compute_argument_yaml(
            a,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for a in schema_order_jit_arguments
    ]

    cpp_schema_order_types = [
        # NB: method here doesn't matter
        r.type
        for a in schema_order_jit_arguments
        for r in cpp.argument(
            a,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
    ]

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    is_factory_method = (
        any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
        and Variant.method not in f.variants
    )

    return OrderedDict(
        [
            ("name", cpp.name(f.func)),
            ("operator_name", str(f.func.name.name)),
            ("overload_name", str(f.func.name.overload_name)),
            ("manual_kernel_registration", f.manual_kernel_registration),
            (
                "category_override",
                f.category_override if f.category_override is not None else "",
            ),
            ("schema_string", f"aten::{f.func}"),
            ("arguments", arguments),
            ("schema_order_cpp_signature", schema_order_cpp_signature),
            ("schema_order_arguments", schema_order_arguments),
            ("method_of", compute_method_of_yaml(f.variants)),
            ("mode", "native"),
            ("python_module", "" if f.python_module is None else f.python_module),
            ("returns", returns),
            ("inplace", f.func.name.name.inplace),
            ("is_factory_method", is_factory_method),
            ("abstract", f.is_abstract),
            ("device_guard", f.device_guard),
            ("with_gil", False),
            ("deprecated", False),
            ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
        ]
    )


# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    return (f.structured or f.structured_delegate is not None) and (
        f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
    )


@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
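        # As written, the expression below is True whenever the set of
        # backends that have a kernel for this op is anything other than
        # exactly the CompositeImplicitAutograd key (optionally together
        # with its NestedTensor companion).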
        "dispatch": str(
            {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {DispatchKey.CompositeImplicitAutograd}
            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {
                DispatchKey.CompositeImplicitAutograd,
                DispatchKey.CompositeImplicitAutogradNestedTensor,
            }
        ),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           RUN IT ALL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def get_custom_build_selector(
    provided_op_registration_allowlist: list[str] | None,
    op_selection_yaml_path: str | None,
) -> SelectiveBuilder:
    assert not (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    ), (
        "Both provided_op_registration_allowlist and "
        + "op_selection_yaml_path can NOT be provided at the "
        + "same time."
    )

    op_registration_allowlist: set[str] | None = None
    if provided_op_registration_allowlist is not None:
        op_registration_allowlist = set(provided_op_registration_allowlist)

    if op_registration_allowlist is not None:
        selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
            op_registration_allowlist,
            True,
            False,
        )
    elif op_selection_yaml_path is not None:
        selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    else:
        selector = SelectiveBuilder.get_nop_selector()

    return selector


def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
    def maybe_create_view_group(
        d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
        funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
        if ViewSchemaKind.aliasing in d:
            view = d.pop(ViewSchemaKind.aliasing)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
            view_copy = d.pop(SchemaKind.functional, None)

            funcs.append(
                NativeFunctionsViewGroup(
                    view=view,
                    view_copy=view_copy,
                    view_inplace=view_inplace,
                )
            )
        # Take the remaining functions that weren't part of the view group
        # and emit them separately
        funcs.extend(d.values())
        return funcs

    grouped_by_views: dict[
        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        schema = f.func.view_signature()
        view_kind: ViewSchemaKind = f.view_schema_kind
        # We need to group up ops relevant to the same "view", consisting of:
        # view op (ViewSchemaKind.aliasing)
        # view_inplace op (ViewSchemaKind.aliasing_inplace)
        # view_copy op (SchemaKind.functional)
        if view_kind == ViewSchemaKind.non_aliasing:
            kind = f.func.kind()
            assert kind not in grouped_by_views[schema]
            grouped_by_views[schema][kind] = f
        else:
            assert (
                view_kind not in grouped_by_views[schema]
            ), f"{view_kind} already in {grouped_by_views[schema].keys()}"
            grouped_by_views[schema][view_kind] = f

    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))


def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
    def flatten_pre_group(
        d: dict[SchemaKind, NativeFunction]
    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            # Invariant: any NativeFunctions that are code-generated
            # should have been grouped into NativeFunctionsGroup objects
            assert not any("generated" in f.tags for f in d.values())
            return list(d.values())
        else:
            return [r]

    # TODO: how come ValuesView isn't a Sequence lol
    pre_grouped_native_functions = pre_group_native_functions(native_functions)
    return list(
        concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
    )


def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    for f in grouped_native_functions:
        native_function_namespaces = set()
        dispatch_keys = set()
        for dispatch_key, backend_idx in backend_indices.items():
            backend_metadata = backend_idx.get_kernel(f)
            if backend_metadata:
                namespace = backend_metadata.cpp_namespace
                dispatch_keys.add(dispatch_key)
                native_function_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            assert (
                len(native_function_namespaces) <= 1
            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
            ns_grouped_kernels[namespace].extend(
                native_function_decl_gen(f, backend_idx)
            )
    return ns_grouped_kernels


def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
    declarations: list[str] = []
    newline = "\n"
    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=4,
        )
        # Convert to a set first to remove duplicate kernel names. Backends are
        # allowed to repeat kernel names; only generate the declaration once!
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(
                newline
            )
        )
    return declarations


# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
1500 :return: a list of string, from the string with all declarations, grouped by namespaces, split by newline. 1501 """ 1502 1503 ns_grouped_kernels = get_ns_grouped_kernels( 1504 grouped_native_functions=grouped_native_functions, 1505 backend_indices=backend_indices, 1506 native_function_decl_gen=native_function_decl_gen, 1507 ) 1508 return get_native_function_declarations_from_ns_grouped_kernels( 1509 ns_grouped_kernels=ns_grouped_kernels 1510 ) 1511 1512 1513def get_kernel_namespace( 1514 *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex 1515) -> str: 1516 backend_metadata = backend_idx.get_kernel(f) 1517 assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, ( 1518 f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} " 1519 f"with dispatch key {backend_idx.dispatch_key}" 1520 f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'." 1521 ) 1522 return ( 1523 backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE 1524 ) 1525 1526 1527# Return native function definitions grouped by dispatch key and custom namespace. 1528# Used in RegisterDispatchKey.cpp and etc. 1529def get_native_function_definitions( 1530 *, 1531 fm: FileManager, 1532 grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 1533 dispatch_key: DispatchKey, 1534 backend_idx: BackendIndex, 1535 selector: SelectiveBuilder, 1536 rocm: bool, 1537 symint: bool, 1538 skip_dispatcher_op_registration: bool, 1539 gen_dispatch_helpers: bool, 1540) -> list[str]: 1541 definitions: list[str] = [] 1542 ns_definitions: dict[str, list[str]] = defaultdict(list) 1543 anonymous_definitions: dict[str, list[str]] = defaultdict(list) 1544 registrations: dict[str, dict[str, list[str]]] = defaultdict(dict) 1545 newline = "\n" 1546 ns_gen = dest.RegisterDispatchKey( 1547 backend_idx, 1548 Target.NAMESPACED_DEFINITION, 1549 selector, 1550 rocm=rocm, 1551 symint=symint, 1552 class_method_name=None, 1553 skip_dispatcher_op_registration=skip_dispatcher_op_registration, 1554 ) 1555 anonymous_gen = dest.RegisterDispatchKey( 1556 backend_idx, 1557 Target.ANONYMOUS_DEFINITION, 1558 selector, 1559 rocm=rocm, 1560 symint=symint, 1561 class_method_name=None, 1562 skip_dispatcher_op_registration=skip_dispatcher_op_registration, 1563 ) 1564 reg_gen = dest.RegisterDispatchKey( 1565 backend_idx, 1566 Target.REGISTRATION, 1567 selector, 1568 rocm=rocm, 1569 symint=symint, 1570 class_method_name=None, 1571 skip_dispatcher_op_registration=skip_dispatcher_op_registration, 1572 ) 1573 for f in grouped_native_functions: 1574 kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace( 1575 "::native", "" 1576 ) 1577 1578 ns_definitions[kernel_namespace].extend( 1579 ns_gen(f), 1580 ) 1581 anonymous_definitions[kernel_namespace].extend( 1582 anonymous_gen(f), 1583 ) 1584 namespace = ( 1585 f.namespace if isinstance(f, NativeFunction) else f.functional.namespace 1586 ) 1587 if namespace not in registrations[kernel_namespace]: 1588 registrations[kernel_namespace] = defaultdict(list) 1589 registrations[kernel_namespace][namespace].extend( 1590 reg_gen(f), 1591 ) 1592 1593 for kernel_namespace in ns_definitions: 1594 if len(ns_definitions[kernel_namespace]) == 0: 1595 continue 1596 ns_helper = NamespaceHelper(namespace_str=kernel_namespace) 1597 registration_body = "" 1598 for namespace in registrations[kernel_namespace]: 1599 if not registrations[kernel_namespace][namespace]: 1600 
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(registrations[kernel_namespace][namespace])}
}};"""
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions


# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h, etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    declarations: list[str] = []
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    newline = "\n"
    func = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )
    for f in grouped_native_functions:
        namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "native", dispatch_key.lower()
        )

        ns_grouped_kernels[namespace].extend(
            func(f),
        )

    for namespace, kernels in ns_grouped_kernels.items():
        if len(kernels) == 0:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=namespace, entity_name="", max_level=3
        )
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(
                newline
            )
        )
    return declarations


# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_native_functions[native_function.namespace].append(native_function)
    schema_registrations = ""
    aten_schema_registrations = []
    custom_namespace = None
    for namespace, funcs in ns_native_functions.items():
        schema_registrations_body = list(
            mapMaybe(RegisterSchema(schema_selector), funcs)
        )
        # NB: we have to separate aten namespace registration from other namespaces,
        # because in the template we hardcoded an operator for ATen already.
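        # Illustrative only (not verbatim codegen output): a hypothetical custom
        # namespace "mylib" with one schema would roughly produce
        #   TORCH_LIBRARY(mylib, m) {
        #     m.def("my_op(Tensor self) -> Tensor");
        #   };
        # whereas namespaces listed in FRAGMENT_NAMESPACES use
        # TORCH_LIBRARY_FRAGMENT instead, so they extend a library defined elsewhere.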
        if namespace == "aten":
            aten_schema_registrations = schema_registrations_body
        else:
            custom_namespace = namespace
            tab = "\t"
            # if the namespace is predefined, we should define a library fragment
            # instead of a new library
            torch_library_macro = (
                "TORCH_LIBRARY_FRAGMENT"
                if namespace in FRAGMENT_NAMESPACES
                else "TORCH_LIBRARY"
            )
            schema_registrations += f"""
{torch_library_macro}({custom_namespace}, m) {{
  {tab.join(schema_registrations_body)}
}};"""
    return (aten_schema_registrations, schema_registrations)


def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )
    method_native_functions = [
        fn for fn in native_functions if Variant.method in fn.variants
    ]
    non_method_native_functions = [
        fn for fn in native_functions if fn not in method_native_functions
    ]
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
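            # The {dispatch_key}Functions.h header generated above is essentially a
            # thin wrapper: its main job is to pull in the matching _inl.h file
            # written just below (e.g. CPUFunctions.h includes CPUFunctions_inl.h
            # for the CPU key).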
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": get_namespaced_declaration(
                        grouped_native_functions=grouped_native_functions,
                        dispatch_key=dispatch_key,
                        backend_idx=backend_indices[dispatch_key],
                        selector=selector,
                        rocm=rocm,
                        symint=True,
                    ),
                },
            )

        del fm


def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        name = group.root_name
        grouped_functions_by_root_name[name].append(group)

    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = len(structured_functions) > 0

        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
("NativeFunctions", "_native"), 1946 ]: 1947 cpu_fm.write( 1948 f"{category}.h", 1949 lambda: { 1950 f"{category}_includes": [ 1951 f"#include <ATen/ops/{name}{suffix}.h>" 1952 for name in sorted(functions_by_root_name.keys()) 1953 ], 1954 f"{category}_declarations": [], 1955 }, 1956 ) 1957 1958 for dispatch_key in dispatch_keys: 1959 if dispatch_key not in functions_keys: 1960 continue 1961 1962 dispatch_namespace = dispatch_key.lower() 1963 dispatch_names = [] 1964 1965 for name, functions in functions_by_root_name.items(): 1966 grouped_functions = grouped_functions_by_root_name.get(name, []) 1967 declarations = list( 1968 concatMap( 1969 dest.RegisterDispatchKey( 1970 backend_indices[dispatch_key], 1971 Target.NAMESPACED_DECLARATION, 1972 selector, 1973 rocm=rocm, 1974 symint=True, 1975 class_method_name=None, 1976 skip_dispatcher_op_registration=False, 1977 ), 1978 grouped_functions, 1979 ) 1980 ) 1981 1982 if len(declarations) == 0: 1983 continue 1984 1985 dispatch_names.append(name) 1986 ops_fm.write_with_template( 1987 f"{name}_{dispatch_namespace}_dispatch.h", 1988 "DispatchKeyFunction.h", 1989 lambda: { 1990 "dispatch_namespace": dispatch_namespace, 1991 "dispatch_namespaced_declarations": declarations, 1992 }, 1993 ) 1994 1995 fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm 1996 inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>" 1997 1998 fm.write_with_template( 1999 f"{dispatch_key}Functions.h", 2000 "DispatchKeyFunctions.h", 2001 lambda: { 2002 "dispatch_key": str(dispatch_key), 2003 "inline_headers": inl_headers, 2004 }, 2005 ) 2006 fm.write_with_template( 2007 f"{dispatch_key}Functions_inl.h", 2008 "DispatchKeyFunctions_inl.h", 2009 lambda: { 2010 "dispatch_namespace": dispatch_namespace, 2011 "DispatchKeyFunctions_inl_includes": [ 2012 f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>" 2013 for name in sorted(dispatch_names) 2014 ], 2015 "dispatch_namespaced_declarations": [], 2016 }, 2017 ) 2018 del fm 2019 2020 cpu_fm.write( 2021 "MethodOperators.h", 2022 lambda: { 2023 "MethodOperators_includes": sorted( 2024 f"#include <ATen/ops/{name}_ops.h>" 2025 for name, functions in functions_by_root_name.items() 2026 if any(Variant.method in fn.variants for fn in functions) 2027 ), 2028 "MethodOperators_declarations": [], 2029 }, 2030 ) 2031 2032 2033def gen_headers( 2034 *, 2035 native_functions: Sequence[NativeFunction], 2036 valid_tags: set[str], 2037 grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 2038 structured_native_functions: Sequence[NativeFunctionsGroup], 2039 static_dispatch_idx: list[BackendIndex], 2040 selector: SelectiveBuilder, 2041 backend_indices: dict[DispatchKey, BackendIndex], 2042 core_fm: FileManager, 2043 cpu_fm: FileManager, 2044 cuda_fm: FileManager, 2045 ops_fm: FileManager, 2046 dispatch_keys: Sequence[DispatchKey], 2047 functions_keys: set[DispatchKey], 2048 rocm: bool, 2049 per_operator_headers: bool, 2050) -> None: 2051 if per_operator_headers: 2052 gen_per_operator_headers( 2053 native_functions=native_functions, 2054 grouped_native_functions=grouped_native_functions, 2055 static_dispatch_idx=static_dispatch_idx, 2056 selector=selector, 2057 backend_indices=backend_indices, 2058 cpu_fm=cpu_fm, 2059 cuda_fm=cuda_fm, 2060 ops_fm=ops_fm, 2061 dispatch_keys=dispatch_keys, 2062 functions_keys=functions_keys, 2063 rocm=rocm, 2064 ) 2065 else: 2066 gen_aggregated_headers( 2067 native_functions=native_functions, 2068 grouped_native_functions=grouped_native_functions, 2069 
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )

    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        attrs: set[str] = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)

            attrs.update(arg.name for arg in func.func.schema_order_arguments())

        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)


def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: list[BackendIndex],
    backend_indices: dict[DispatchKey, BackendIndex],
    aoti_fm: FileManager,
    core_fm: FileManager,
    cpu_fm: FileManager,
    cpu_vec_fm: FileManager,
    cuda_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
    update_aoti_c_shim: bool,
) -> None:
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        if per_operator_headers:

            def operator_headers() -> list[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                    # The above has_kernel test on a group will only test for
                    # the existence of out dispatch, because that's how
                    # structured kernels work. But sometimes functions can be
                    # grouped but not be structured, and then you need to check
                    # each individual piece, as they may have manual dispatch
                    # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                    # TODO: this condition is a bit questionable
                    # (It has to do with the fact that structured kernels get generated kernels
                    # to the Meta + CompositeExplicitAutogradNonFunctional keys).
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutogradNonFunctional,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue

                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if (
                        dispatch_key
                        == DispatchKey.CompositeExplicitAutogradNonFunctional
                    ):
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )

                return sorted(set(headers))

        else:

            def operator_headers() -> list[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers

        backend_index = backend_indices[dispatch_key]
        ns_grouped_native_functions = defaultdict(list)
        for grouped_native_function in grouped_native_functions:
            namespace = (
                grouped_native_function.namespace
                if isinstance(grouped_native_function, NativeFunction)
                else grouped_native_function.functional.namespace
            )
            ns_grouped_native_functions[namespace].append(grouped_native_function)

        dispatch_namespace = str(dispatch_key).lower()

        # CompositeImplicitAutogradNestedTensor does not currently use the generated
        # dispatch helpers; compilation would fail when the `-Werror=unused-function`
        # flag is set.
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )

        dispatch_definitions = get_native_function_definitions(
            fm=fm,
            grouped_native_functions=grouped_native_functions,
            dispatch_key=dispatch_key,
            backend_idx=backend_index,
            selector=selector,
            rocm=rocm,
            symint=True,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
            gen_dispatch_helpers=gen_dispatch_helpers,
        )
        fm.write_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            lambda: {
"extra_cuda_headers": extra_cuda_headers 2294 if is_cuda_dispatch_key(dispatch_key) 2295 else "", 2296 "external_backend_headers": "", 2297 "dispatch_headers": dest.gen_registration_headers( 2298 backend_index, per_operator_headers, rocm 2299 ), 2300 "ops_headers": operator_headers(), 2301 "dispatch_helpers": "", 2302 "dispatch_definitions": dispatch_definitions, 2303 }, 2304 ) 2305 2306 for g in structured_native_functions: 2307 if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key): 2308 continue 2309 name = g.functional.func.name.name 2310 if dispatch_key is DispatchKey.CPU: 2311 assert fm is cpu_fm 2312 fm.write_with_template( 2313 f"UfuncCPU_{name}.cpp", 2314 "UfuncCPU.cpp", 2315 lambda: { 2316 "meta_declaration": compute_meta_function_declaration(g), 2317 "native_declaration": dest.compute_native_function_declaration( 2318 g, backend_indices[dispatch_key] 2319 ), 2320 "native_definitions": dest.compute_ufunc_cpu(g), 2321 }, 2322 ) 2323 cpu_vec_fm.write_with_template( 2324 f"UfuncCPUKernel_{name}.cpp", 2325 "UfuncCPUKernel.cpp", 2326 lambda: { 2327 "name": name, 2328 "native_definitions": dest.compute_ufunc_cpu_kernel(g), 2329 }, 2330 ) 2331 elif dispatch_key is DispatchKey.CUDA: 2332 cuda_headers = "#include <ATen/native/cuda/Loops.cuh>" 2333 if rocm: 2334 cuda_headers = "#include <ATen/native/hip/Loops.cuh>" 2335 fm.write_with_template( 2336 f"UfuncCUDA_{name}.cu", 2337 "UfuncCUDA.cu", 2338 lambda: { 2339 "name": name, 2340 "cuda_headers": cuda_headers, 2341 "meta_declaration": compute_meta_function_declaration(g), 2342 "native_declaration": dest.compute_native_function_declaration( 2343 g, backend_indices[dispatch_key] 2344 ), 2345 "native_definitions": dest.compute_ufunc_cuda(g), 2346 }, 2347 ) 2348 else: 2349 raise AssertionError(f"unrecognized {dispatch_key} for ufunc") 2350 2351 structured_func_group_dict = {} 2352 for func_group in structured_native_functions: 2353 for func in func_group.functions(): 2354 if func.structured_delegate is not None: 2355 structured_func_group_dict[func.structured_delegate] = func_group 2356 break 2357 2358 if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA): 2359 fallbacks = {} 2360 for func in native_functions: 2361 op_name = get_fallback_op_name(func) 2362 if op_name in inductor_fallback_ops: 2363 fallbacks[op_name] = func 2364 fallback_native_functions = tuple( 2365 value for _, value in sorted(fallbacks.items()) 2366 ) 2367 2368 # header files were checked in for ABI-compatiblilty checking 2369 header_file_name = f"c_shim_{dispatch_key.lower()}.h" 2370 new_header = gen_aoti_c_shim( 2371 fallback_native_functions, 2372 structured_func_group_dict, 2373 dispatch_key, 2374 backend_indices, 2375 header=True, 2376 includes="", 2377 ) 2378 if update_aoti_c_shim: 2379 aoti_fm.write( 2380 header_file_name, 2381 lambda: new_header, 2382 ) 2383 else: 2384 try: 2385 with open( 2386 os.path.join(aoti_fm.install_dir, header_file_name) 2387 ) as old_file: 2388 old_header = old_file.read() 2389 assert ( 2390 old_header == new_header 2391 ), """ 2392 2393WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This 2394indicates an AOTInductor fallback operator ABI backward compatibility breakage!!! 2395Only in a limited number of situations, this is allowed: 2396 23971. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py. 2398If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing 2399C shim header files. 2400 24012. 
2. You added a new default argument to an existing fallback op. This is clearly a BC-breaking
change in AOTInductor land. In this case, you need to keep a manual copy of that existing
fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
number of that fallback op in the newly generated C shim files, and update the cpp wrapper
codegen to generate the correct cpp call for this op. Contact the AOTInductor team for assistance.

                        """
                except FileNotFoundError:
                    print(
                        f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
                    )

            # cpp files are always generated on-the-fly
            def headers_for_aoti() -> str:
                headers = []
                for func in fallback_native_functions:
                    header = get_header_for_aoti(
                        func, structured_func_group_dict, dispatch_key, backend_indices
                    )
                    if header is not None:
                        headers.append(header)
                return "\n".join(sorted(set(headers)))

            extra_headers = (
                extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
            )

            aoti_fm.write(
                f"c_shim_{dispatch_key.lower()}.cpp",
                lambda: gen_aoti_c_shim(
                    fallback_native_functions,
                    structured_func_group_dict,
                    dispatch_key,
                    backend_indices,
                    header=False,
                    includes=headers_for_aoti() + "\n" + extra_headers,
                ),
            )

        del fm

    # BackendSelect is generated specially
    def gen_backend_select() -> dict[str, list[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }

    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)

    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions, schema_selector=schema_selector
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "aten_schema_registrations": []
            if skip_dispatcher_op_registration
            else aten_schema_registrations,
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else schema_registrations,
        },
    )

    def key_func(
        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )

    cpu_fm.write("Functions.cpp", dict)
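    # Passing the `dict` builtin as the env callable gives an empty substitution
    # environment: Functions.cpp and TensorMethods.cpp apparently need no
    # per-operator substitutions, so only the static template content is emitted.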
core_fm.write("TensorMethods.cpp", dict) 2521 2522 core_fm.write( 2523 "ATenOpList.cpp", 2524 lambda: { 2525 "aten_ops": list(mapMaybe(compute_aten_op, native_functions)), 2526 }, 2527 ) 2528 2529 def functionalization_env_callable( 2530 g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, 2531 ) -> dict[str, list[str]]: 2532 def gen_op_headers( 2533 g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, 2534 ) -> list[str]: 2535 if isinstance(g, NativeFunctionsViewGroup): 2536 # view ops always get a functionalization kernel 2537 headers = [ 2538 f"#include <ATen/ops/{g.view.root_name}_native.h>", 2539 f"#include <ATen/ops/{g.view.root_name}_ops.h>", 2540 ] 2541 if g.view_copy is not None: 2542 headers += [ 2543 f"#include <ATen/ops/{g.view_copy.root_name}_native.h>", 2544 f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>", 2545 ] 2546 return headers 2547 elif isinstance(g, NativeFunctionsGroup): 2548 headers = [ 2549 f"#include <ATen/ops/{g.functional.root_name}_native.h>", 2550 f"#include <ATen/ops/{g.functional.root_name}_ops.h>", 2551 f"#include <ATen/ops/{g.out.root_name}_native.h>", 2552 f"#include <ATen/ops/{g.out.root_name}_ops.h>", 2553 ] 2554 if g.inplace is not None: 2555 headers += [ 2556 f"#include <ATen/ops/{g.inplace.root_name}_native.h>", 2557 f"#include <ATen/ops/{g.inplace.root_name}_ops.h>", 2558 ] 2559 if g.mutable is not None: 2560 headers += [ 2561 f"#include <ATen/ops/{g.mutable.root_name}_native.h>", 2562 f"#include <ATen/ops/{g.mutable.root_name}_ops.h>", 2563 ] 2564 return headers 2565 else: 2566 return [ 2567 f"#include <ATen/ops/{g.root_name}_native.h>", 2568 f"#include <ATen/ops/{g.root_name}_ops.h>", 2569 ] 2570 2571 return { 2572 "ops_headers": gen_op_headers(g), 2573 "func_definitions": gen_functionalization_definition( 2574 selector, 2575 g, 2576 ), 2577 "func_registrations": gen_functionalization_registration( 2578 selector, 2579 g, 2580 backend_indices[DispatchKey.CompositeImplicitAutograd], 2581 ), 2582 } 2583 2584 all_groups: list[ 2585 NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup 2586 ] = list(structured_native_functions) + list( 2587 view_groups # type: ignore[assignment, arg-type, operator] 2588 ) 2589 # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly. 2590 # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because: 2591 # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) 2592 # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. 2593 # Although this could go away long-term if we add a dedicated dispatch key for decompositions. 
    structured_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    for f in native_functions:
        if f.func.name not in structured_map and f.func.name not in view_map:
            all_groups.append(f)

    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )

    cpu_fm.write(
        "FunctionalInverses.h",
        lambda: {
            "view_inverse_declarations": list(
                mapMaybe(
                    lambda g: gen_functionalization_view_inverse_declaration(
                        selector, g
                    ),
                    view_groups,
                )
            )
        },
    )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
    # Backends that use functionalization and don't know how to handle aliasing ops
    # are expected to implement kernels for these {view}_copy kernels instead.
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
    # so we codegen the following:
    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
    #     These are never explicitly invoked by the functionalization pass,
    #     but they could theoretically be called from user code (I added these kernels for completeness,
    #     since the ops are part of the public API).
    # (2) A derivative formula for every {view}_copy operator
    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
    #     so rather than stamping all of the entries out in derivatives.yaml,
    #     we codegen them in.
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
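    # Rough sketch of the composite kernel generated for a hypothetical view op `foo`
    # (illustrative only, not verbatim codegen output): the {view}_copy variant calls
    # the view op and then materializes the result, roughly
    #   at::Tensor foo_copy(const at::Tensor& self) {
    #     auto output = at::foo(self);
    #     return output.clone();
    #   }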
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is important as it ensures we
                    # set the visibility on generated view_copy kernels
                    # correctly
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is also important for correct visibility
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in [g.inplace, g.mutable, g.functional]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(
                    GenCompositeViewCopyKernel(
                        backend_indices[
                            DispatchKey.CompositeExplicitAutogradNonFunctional
                        ]
                    ),
                    view_groups,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
            "GeneratedCompositeOut_Definitions": list(
                mapMaybe(
                    gen_composite_out_kernel,
                    structured_native_functions,
                )
            ),
        },
    )


def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    cpu_fm.write(
        "Declarations.yaml",
        lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
    )


def get_torchgen_root() -> Path:
    """
    If you're depending on torchgen out-of-tree, you can use the root to figure
    out the path to native_functions.yaml
    """
    return Path(__file__).parent.resolve()


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--aoti-install-dir",
        "--aoti_install_dir",
        help="output directory for AOTInductor shim",
        default="torch/csrc/inductor/aoti_torch/generated",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--update-aoti-c-shim",
        action="store_true",
        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops "
        "in torchgen/aoti/fallback_ops.py. "
" 2823 "WARNING: Do not use this unless you are sure what you are doing!!!", 2824 ) 2825 2826 options = parser.parse_args() 2827 2828 selector = get_custom_build_selector( 2829 options.op_registration_whitelist, 2830 options.op_selection_yaml_path, 2831 ) 2832 2833 native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") 2834 tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") 2835 2836 from torchgen.model import dispatch_keys 2837 2838 # TODO: stop generating CUDA kernels for non-CUDA builds 2839 ignore_keys = set() 2840 if not options.mps: 2841 ignore_keys.add(DispatchKey.MPS) 2842 2843 if DispatchKey.MPS in dispatch_keys: 2844 del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] 2845 2846 parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) 2847 valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] 2848 native_functions, backend_indices = ( 2849 parsed_yaml.native_functions, 2850 parsed_yaml.backend_indices, 2851 ) 2852 2853 grouped_native_functions = get_grouped_native_functions(native_functions) 2854 2855 structured_native_functions = [ 2856 g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup) 2857 ] 2858 native_functions_with_view_groups = get_grouped_by_view_native_functions( 2859 native_functions 2860 ) 2861 view_groups = [ 2862 g 2863 for g in native_functions_with_view_groups 2864 if isinstance(g, NativeFunctionsViewGroup) 2865 ] 2866 2867 # NB: It is mandatory to NOT use os.path.join here, as the install directory 2868 # will eventually be ingested by cmake, which does not respect Windows style 2869 # path slashes. If you switch this to use os.path.join, you'll get an error 2870 # like: 2871 # 2872 # Syntax error in cmake code when parsing string 2873 # 2874 # C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h 2875 # 2876 # Invalid character escape '\c'. 
    core_install_dir = f"{options.install_dir}/core"
    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    aoti_install_dir = f"{options.aoti_install_dir}"
    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)

    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)

    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    static_dispatch_idx: list[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_idx = [
            backend_indices[DispatchKey.parse(key)]
            for key in options.static_dispatch_backend
        ]
        for key in options.static_dispatch_backend:
            dp_key = DispatchKey.parse(key)
            if dp_key not in functions_keys:
                functions_keys.add(dp_key)

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cpu_vec_fm=cpu_vec_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
            update_aoti_c_shim=options.update_aoti_c_shim,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (cuda_fm, "cuda_"),
            (ops_fm, "ops_"),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))


if __name__ == "__main__":
    main()