1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport argparse 4*da0073e9SAndroid Build Coastguard Workerimport functools 5*da0073e9SAndroid Build Coastguard Workerimport json 6*da0073e9SAndroid Build Coastguard Workerimport os 7*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict, namedtuple, OrderedDict 8*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field 9*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path 10*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Literal, Sequence, TypeVar 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerimport yaml 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.dispatcher as dispatcher 15*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.meta as meta 16*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.native as native 17*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.structured as structured 18*da0073e9SAndroid Build Coastguard Workerimport torchgen.dest as dest 19*da0073e9SAndroid Build Coastguard Workerfrom torchgen.aoti.fallback_ops import inductor_fallback_ops 20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp 21*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.translate import translate 22*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import ( 23*da0073e9SAndroid Build Coastguard Worker Binding, 24*da0073e9SAndroid Build Coastguard Worker CppSignature, 25*da0073e9SAndroid Build Coastguard Worker CppSignatureGroup, 26*da0073e9SAndroid Build Coastguard Worker DispatcherSignature, 27*da0073e9SAndroid Build Coastguard Worker NamedCType, 28*da0073e9SAndroid Build Coastguard Worker NativeSignature, 29*da0073e9SAndroid Build Coastguard Worker 
SpecialArgName, 30*da0073e9SAndroid Build Coastguard Worker) 31*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import ( 32*da0073e9SAndroid Build Coastguard Worker method_with_native_function, 33*da0073e9SAndroid Build Coastguard Worker native_function_manager, 34*da0073e9SAndroid Build Coastguard Worker with_native_function, 35*da0073e9SAndroid Build Coastguard Worker with_native_function_and_indices, 36*da0073e9SAndroid Build Coastguard Worker) 37*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen_aoti_c_shim import ( 38*da0073e9SAndroid Build Coastguard Worker gen_aoti_c_shim, 39*da0073e9SAndroid Build Coastguard Worker gen_static_dispatch_backend_call_signature, 40*da0073e9SAndroid Build Coastguard Worker get_fallback_op_name, 41*da0073e9SAndroid Build Coastguard Worker get_header_for_aoti, 42*da0073e9SAndroid Build Coastguard Worker) 43*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen_functionalization_type import ( 44*da0073e9SAndroid Build Coastguard Worker gen_functionalization_definition, 45*da0073e9SAndroid Build Coastguard Worker gen_functionalization_registration, 46*da0073e9SAndroid Build Coastguard Worker gen_functionalization_view_inverse_declaration, 47*da0073e9SAndroid Build Coastguard Worker GenCompositeViewCopyKernel, 48*da0073e9SAndroid Build Coastguard Worker) 49*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing 50*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 51*da0073e9SAndroid Build Coastguard Worker Argument, 52*da0073e9SAndroid Build Coastguard Worker BackendIndex, 53*da0073e9SAndroid Build Coastguard Worker BackendMetadata, 54*da0073e9SAndroid Build Coastguard Worker BaseOperatorName, 55*da0073e9SAndroid Build Coastguard Worker DEFAULT_KERNEL_NAMESPACE, 56*da0073e9SAndroid Build Coastguard Worker DispatchKey, 57*da0073e9SAndroid Build Coastguard Worker FRAGMENT_NAMESPACES, 58*da0073e9SAndroid Build Coastguard Worker FunctionSchema, 
59*da0073e9SAndroid Build Coastguard Worker is_cuda_dispatch_key, 60*da0073e9SAndroid Build Coastguard Worker is_generic_dispatch_key, 61*da0073e9SAndroid Build Coastguard Worker is_ufunc_dispatch_key, 62*da0073e9SAndroid Build Coastguard Worker is_xpu_dispatch_key, 63*da0073e9SAndroid Build Coastguard Worker Location, 64*da0073e9SAndroid Build Coastguard Worker NativeFunction, 65*da0073e9SAndroid Build Coastguard Worker NativeFunctionsGroup, 66*da0073e9SAndroid Build Coastguard Worker NativeFunctionsViewGroup, 67*da0073e9SAndroid Build Coastguard Worker OperatorName, 68*da0073e9SAndroid Build Coastguard Worker OptionalType, 69*da0073e9SAndroid Build Coastguard Worker SchemaKind, 70*da0073e9SAndroid Build Coastguard Worker SelfArgument, 71*da0073e9SAndroid Build Coastguard Worker STRUCTURED_DISPATCH_KEYS, 72*da0073e9SAndroid Build Coastguard Worker TensorOptionsArguments, 73*da0073e9SAndroid Build Coastguard Worker Type, 74*da0073e9SAndroid Build Coastguard Worker Variant, 75*da0073e9SAndroid Build Coastguard Worker ViewSchemaKind, 76*da0073e9SAndroid Build Coastguard Worker) 77*da0073e9SAndroid Build Coastguard Workerfrom torchgen.native_function_generation import ( 78*da0073e9SAndroid Build Coastguard Worker add_generated_native_functions, 79*da0073e9SAndroid Build Coastguard Worker gen_composite_functional_kernel, 80*da0073e9SAndroid Build Coastguard Worker gen_composite_out_kernel, 81*da0073e9SAndroid Build Coastguard Worker pre_group_native_functions, 82*da0073e9SAndroid Build Coastguard Worker) 83*da0073e9SAndroid Build Coastguard Workerfrom torchgen.selective_build.selector import SelectiveBuilder 84*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import ( 85*da0073e9SAndroid Build Coastguard Worker assert_never, 86*da0073e9SAndroid Build Coastguard Worker concatMap, 87*da0073e9SAndroid Build Coastguard Worker context, 88*da0073e9SAndroid Build Coastguard Worker FileManager, 89*da0073e9SAndroid Build Coastguard Worker make_file_manager, 
90*da0073e9SAndroid Build Coastguard Worker mapMaybe, 91*da0073e9SAndroid Build Coastguard Worker NamespaceHelper, 92*da0073e9SAndroid Build Coastguard Worker Target, 93*da0073e9SAndroid Build Coastguard Worker) 94*da0073e9SAndroid Build Coastguard Workerfrom torchgen.yaml_utils import YamlDumper, YamlLoader 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard WorkerT = TypeVar("T") 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker# Welcome to the ATen code generator v2! The ATen code generator is 100*da0073e9SAndroid Build Coastguard Worker# responsible for parsing native_functions.yaml and then generating 101*da0073e9SAndroid Build Coastguard Worker# various generated files (e.g., TypeDefault.cpp) based on the operators 102*da0073e9SAndroid Build Coastguard Worker# defined in this file. This means that the code generator knows how to 103*da0073e9SAndroid Build Coastguard Worker# parse function schema, and then translate this into various C++ types 104*da0073e9SAndroid Build Coastguard Worker# and boilerplate code. 105*da0073e9SAndroid Build Coastguard Worker# 106*da0073e9SAndroid Build Coastguard Worker# Some things to know about this file when you modify it: 107*da0073e9SAndroid Build Coastguard Worker# 108*da0073e9SAndroid Build Coastguard Worker# - This file has STRICT mypy typechecking. Typecheck it with 109*da0073e9SAndroid Build Coastguard Worker# `mypy --config mypy-strict.ini` in the root source directory 110*da0073e9SAndroid Build Coastguard Worker# 111*da0073e9SAndroid Build Coastguard Worker# - Most of the heavy lifting lives in external modules: 112*da0073e9SAndroid Build Coastguard Worker# - 'model' has the data model for native_functions.yaml. 
#     The classes in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with.  There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API.  See each
#     of these respective files for more information

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    """YAML loader that tags every constructed mapping with its source line.

    Each mapping is built by the parent loader and then gains a
    ``"__line__"`` entry holding the 1-based line on which the mapping
    node starts.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Construct the mapping for ``node`` and record its starting line."""
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping["__line__"] = node.start_mark.line + 1
        return mapping
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])


# Results of parse_native_yaml() / parse_tags_yaml(), keyed by the yaml path.
_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}


def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Turn an already-loaded native_functions.yaml document into a ParsedYaml.

    Arguments:
        es: The loaded YAML document; must be a list of dicts, each carrying
            an integer ``"__line__"`` entry and a ``"func"`` entry.
        valid_tags: Forwarded to ``NativeFunction.from_yaml``.
        ignore_keys: Forwarded to ``NativeFunction.from_yaml``.
        path: File name used when building the ``Location`` of each entry.
        skip_native_fns_gen: When true, ``add_generated_native_functions``
            is not called.

    Returns:
        ``ParsedYaml(native_functions, backend_indices)`` where
        ``backend_indices`` is a defaultdict from DispatchKey to BackendIndex.

    Raises:
        AssertionError: if ``es`` or one of its entries is malformed.
    """
    assert isinstance(es, list)
    rs: list[NativeFunction] = []
    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for e in es:
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        funcs = e.get("func")
        assert funcs is not None, f"missed 'func' in {e}"
        # The message is built lazily; presumably `context` attaches it to
        # errors raised inside the block -- see torchgen.utils.context.
        with context(lambda: f"in {loc}:\n {funcs}"):
            func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
            rs.append(func)
            BackendIndex.grow_index(bs, m)
    error_check_native_functions(rs)
    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )
    )
    if not skip_native_fns_gen:
        add_generated_native_functions(rs, bs)
    for k, v in bs.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k),
            index=v,
        )
    return ParsedYaml(rs, indices)


def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    """Collect the tag names declared in an already-loaded tags.yaml document.

    Arguments:
        es: The loaded YAML document; must be a list of dicts, each carrying
            an integer ``"__line__"`` entry, a ``"tag"`` entry and a
            non-empty ``"desc"`` entry.
        path: File name used when building the ``Location`` of each entry.

    Returns:
        The set of all ``"tag"`` values.

    Raises:
        AssertionError: if ``es`` or an entry is malformed, or a tag has no
            description.
        KeyError: if an entry has no ``"tag"`` key.
    """
    assert isinstance(es, list)
    rs: set[str] = set()
    for e in es:
        # Same shape check as parse_native_yaml_struct, so a non-mapping entry
        # fails with a readable message instead of an AttributeError.
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        with context(lambda: f"in {loc}:\n {tags}"):
            name = e["tag"]
            desc = e.get("desc", "")
            # ensure that each tag has a non-empty description
            assert desc != "", f"tag {name!r} is missing a non-empty 'desc'"
            rs.add(name)
    return rs
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
    """Load the tags yaml file at ``path`` and return the set of tag names.

    The result is stored in ``_GLOBAL_PARSE_TAGS_YAML_CACHE`` under ``path``
    (in addition to the ``lru_cache`` on this function), so the file is read
    at most once per path.
    """
    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            # NOTE(review): safe only if YamlLoader is a safe loader -- confirm.
            es = yaml.load(f, Loader=LineLoader)
            _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]


def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    """Parse the native functions yaml at ``path`` into a ParsedYaml.

    Arguments:
        path: Location of the native functions yaml; also the cache key.
        tags_yaml_path: Location of the tags yaml used to obtain valid tags.
        ignore_keys: Forwarded to ``parse_native_yaml_struct``.
        skip_native_fns_gen: Forwarded to ``parse_native_yaml_struct``.
        loaded_yaml: If given, used instead of reading and loading ``path``.

    NOTE: the cache is keyed on ``path`` alone. A second call with the same
    ``path`` returns the first result even if ``tags_yaml_path``,
    ``ignore_keys``, ``skip_native_fns_gen`` or ``loaded_yaml`` differ.
    """
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        valid_tags = parse_tags_yaml(tags_yaml_path)

        # if a loaded yaml is provided, use that instead of reading from path
        if loaded_yaml is None:
            # Explicit encoding so parsing does not depend on the platform locale.
            with open(path, encoding="utf-8") as f:
                # NOTE(review): safe only if YamlLoader is a safe loader -- confirm.
                es = yaml.load(f, Loader=LineLoader)
        else:
            es = loaded_yaml

        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
            es,
            valid_tags,
            ignore_keys,
            path=path,
            skip_native_fns_gen=skip_native_fns_gen,
        )

    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]


# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Check invariants that span several NativeFunctions.

    Two things are verified:
      * every ``structured_delegate`` names an operator that is present in
        ``funcs`` and is marked ``structured``;
      * every op tagged ``inplace_view`` (other than ``resize_``,
        ``resize_as_`` and the ``set_`` overloads) has an in-place base name
        and an op with the matching out-of-place base name exists.

    Raises:
        AssertionError: if any of the checks fails.
    """
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) != "resize_"
            and str(f.func.name) != "resize_as_"
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            # .get() keeps the defaultdict from growing an empty entry on a miss.
            # The message names the out-of-place op that is missing, not the
            # in-place op being checked.
            assert base_func_map.get(out_of_place_base_name), (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, but it didn't find one. "
            )


# Characters that need a backslash escape inside a C++ string literal.
# Applied in a single pass, so already-escaped output is never re-escaped.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\v": "\\v",
        "\t": "\\t",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal"""
    # A raw carriage return would end the source line inside the literal,
    # so it is escaped along with the other control characters.
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated.
# This pattern makes it convenient to use map, concatMap
# and similar functional combinators.


def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    """Return the dispatch keys considered for static dispatch.

    With no backends the result is empty; otherwise it is the dispatch key
    of every backend followed by the four composite dispatch keys.
    """
    if not backends:
        return []
    return [backend.dispatch_key for backend in backends] + [
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
    ]


def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    """Pick the dispatch key that ``f`` statically dispatches to for one backend.

    Returns the backend's own key when ``f`` has a structured delegate or the
    backend has a kernel for it, otherwise the first composite key for which
    ``f`` reports a kernel, otherwise ``None``.
    """
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    elif f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None


def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    """Build the per-operator ``#include`` lines needed to statically dispatch ``f``.

    Returns ``None`` when ``backend_index`` is ``None`` or ``f`` uses manual
    kernel registration. Otherwise returns one include line per backend that
    resolves to a dispatch key, joined with newlines (possibly empty).
    """
    if backend_index is None or f.manual_kernel_registration:
        return None

    dispatch_keys = (get_static_dispatch_backend(f, index) for index in backend_index)
    return "\n".join(
        f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
        for dispatch_key in dispatch_keys
        if dispatch_key is not None
    )


def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    """Return one ``<ATen/{key}Functions.h>`` include per static dispatch key."""
    return [
        f"#include <ATen/{dispatch_key}Functions.h>"
        for dispatch_key in static_dispatch_keys(backends)
    ]


# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
381*da0073e9SAndroid Build Coastguard Workerdef translate_args( 382*da0073e9SAndroid Build Coastguard Worker sig: CppSignature | DispatcherSignature, 383*da0073e9SAndroid Build Coastguard Worker cpp_sig: CppSignature, 384*da0073e9SAndroid Build Coastguard Worker) -> str: 385*da0073e9SAndroid Build Coastguard Worker # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings 386*da0073e9SAndroid Build Coastguard Worker def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]: 387*da0073e9SAndroid Build Coastguard Worker output_bindings: list[Binding] = [] 388*da0073e9SAndroid Build Coastguard Worker for binding in input_bindings: 389*da0073e9SAndroid Build Coastguard Worker if binding.name == "memory_format": 390*da0073e9SAndroid Build Coastguard Worker spl_mem_format_binding = Binding( 391*da0073e9SAndroid Build Coastguard Worker nctype=NamedCType( 392*da0073e9SAndroid Build Coastguard Worker SpecialArgName.possibly_redundant_memory_format, 393*da0073e9SAndroid Build Coastguard Worker binding.nctype.type, 394*da0073e9SAndroid Build Coastguard Worker ), 395*da0073e9SAndroid Build Coastguard Worker name=binding.name, 396*da0073e9SAndroid Build Coastguard Worker default=binding.default, 397*da0073e9SAndroid Build Coastguard Worker argument=binding.argument, 398*da0073e9SAndroid Build Coastguard Worker ) 399*da0073e9SAndroid Build Coastguard Worker output_bindings.append(spl_mem_format_binding) 400*da0073e9SAndroid Build Coastguard Worker else: 401*da0073e9SAndroid Build Coastguard Worker output_bindings.append(binding) 402*da0073e9SAndroid Build Coastguard Worker return output_bindings 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker src_bindings = list(sig.arguments()) 405*da0073e9SAndroid Build Coastguard Worker goal_bindings = list(cpp_sig.arguments()) 406*da0073e9SAndroid Build Coastguard Worker # When last argument of CPP signature has 
SpecialArgName.possibly_redundant_memory_format NCType, 407*da0073e9SAndroid Build Coastguard Worker # get memory_format bindings of dispatcher signature to have the same NCType as well 408*da0073e9SAndroid Build Coastguard Worker for arg in goal_bindings: 409*da0073e9SAndroid Build Coastguard Worker if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format: 410*da0073e9SAndroid Build Coastguard Worker src_bindings = add_spl_memory_format_binding(src_bindings) 411*da0073e9SAndroid Build Coastguard Worker break 412*da0073e9SAndroid Build Coastguard Worker exprs = translate(src_bindings, goal_bindings) 413*da0073e9SAndroid Build Coastguard Worker return ", ".join(a.expr for a in exprs) 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Workerdef generate_static_dispatch_backend_call( 417*da0073e9SAndroid Build Coastguard Worker sig: CppSignature | DispatcherSignature, 418*da0073e9SAndroid Build Coastguard Worker f: NativeFunction, 419*da0073e9SAndroid Build Coastguard Worker backend_index: BackendIndex, 420*da0073e9SAndroid Build Coastguard Worker) -> str: 421*da0073e9SAndroid Build Coastguard Worker cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) 422*da0073e9SAndroid Build Coastguard Worker name = cpp_sig.name() 423*da0073e9SAndroid Build Coastguard Worker exprs = translate_args(sig, cpp_sig) 424*da0073e9SAndroid Build Coastguard Worker backend_metadata = backend_index.get_kernel(f) 425*da0073e9SAndroid Build Coastguard Worker kernel_ns = ( 426*da0073e9SAndroid Build Coastguard Worker backend_metadata.cpp_namespace 427*da0073e9SAndroid Build Coastguard Worker if backend_metadata and backend_metadata.cpp_namespace 428*da0073e9SAndroid Build Coastguard Worker else DEFAULT_KERNEL_NAMESPACE 429*da0073e9SAndroid Build Coastguard Worker ) 430*da0073e9SAndroid Build Coastguard Worker ns = kernel_ns.replace("::native", "") 431*da0073e9SAndroid Build Coastguard Worker return 
f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});" 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Workerdef generate_static_dispatch_fallback_call( 435*da0073e9SAndroid Build Coastguard Worker sig: CppSignature | DispatcherSignature, 436*da0073e9SAndroid Build Coastguard Worker f: NativeFunction, 437*da0073e9SAndroid Build Coastguard Worker backend_indices: list[BackendIndex], 438*da0073e9SAndroid Build Coastguard Worker) -> str: 439*da0073e9SAndroid Build Coastguard Worker cpp_sigs = CppSignatureGroup.from_native_function( 440*da0073e9SAndroid Build Coastguard Worker f, method=False, fallback_binding=False 441*da0073e9SAndroid Build Coastguard Worker ) 442*da0073e9SAndroid Build Coastguard Worker if sig.symint and f.func.has_symint(): 443*da0073e9SAndroid Build Coastguard Worker cpp_sig = cpp_sigs.symint_signature 444*da0073e9SAndroid Build Coastguard Worker else: 445*da0073e9SAndroid Build Coastguard Worker cpp_sig = cpp_sigs.signature 446*da0073e9SAndroid Build Coastguard Worker assert cpp_sig is not None 447*da0073e9SAndroid Build Coastguard Worker name = cpp_sig.name() 448*da0073e9SAndroid Build Coastguard Worker exprs = translate_args(sig, cpp_sig) 449*da0073e9SAndroid Build Coastguard Worker ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "") 450*da0073e9SAndroid Build Coastguard Worker if f.has_composite_explicit_autograd_kernel: 451*da0073e9SAndroid Build Coastguard Worker return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" 452*da0073e9SAndroid Build Coastguard Worker elif f.has_composite_explicit_autograd_non_functional_kernel: 453*da0073e9SAndroid Build Coastguard Worker return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});" 454*da0073e9SAndroid Build Coastguard Worker elif f.has_composite_implicit_autograd_kernel: 455*da0073e9SAndroid Build Coastguard Worker return 
f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});" 456*da0073e9SAndroid Build Coastguard Worker elif f.has_composite_implicit_autograd_nested_tensor_kernel: 457*da0073e9SAndroid Build Coastguard Worker return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});" 458*da0073e9SAndroid Build Coastguard Worker else: 459*da0073e9SAndroid Build Coastguard Worker return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\ 460*da0073e9SAndroid Build Coastguard Worker{', '.join([str(index.dispatch_key)for index in backend_indices])} ");""" 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Workerdef static_dispatch( 464*da0073e9SAndroid Build Coastguard Worker sig: CppSignature | DispatcherSignature, 465*da0073e9SAndroid Build Coastguard Worker f: NativeFunction, 466*da0073e9SAndroid Build Coastguard Worker backend_indices: list[BackendIndex], 467*da0073e9SAndroid Build Coastguard Worker) -> str: 468*da0073e9SAndroid Build Coastguard Worker """ 469*da0073e9SAndroid Build Coastguard Worker For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one 470*da0073e9SAndroid Build Coastguard Worker backends exsit, fallback to static dispatch by determining dispatch key from inputs. 471*da0073e9SAndroid Build Coastguard Worker Arguments: 472*da0073e9SAndroid Build Coastguard Worker sig: A CppSignature or DispatcherSignature for this native function we want to use. 473*da0073e9SAndroid Build Coastguard Worker f: NativeFunction to generate static dispatch. 474*da0073e9SAndroid Build Coastguard Worker backend_indices: All available backends. 
475*da0073e9SAndroid Build Coastguard Worker Return: 476*da0073e9SAndroid Build Coastguard Worker C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);" 477*da0073e9SAndroid Build Coastguard Worker """ 478*da0073e9SAndroid Build Coastguard Worker if len(backend_indices) == 0 or f.manual_kernel_registration: 479*da0073e9SAndroid Build Coastguard Worker return "" 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker keys = [ 482*da0073e9SAndroid Build Coastguard Worker b 483*da0073e9SAndroid Build Coastguard Worker for b in backend_indices 484*da0073e9SAndroid Build Coastguard Worker if b.has_kernel(f) 485*da0073e9SAndroid Build Coastguard Worker or ( 486*da0073e9SAndroid Build Coastguard Worker f.structured_delegate is not None 487*da0073e9SAndroid Build Coastguard Worker and b.dispatch_key in STRUCTURED_DISPATCH_KEYS 488*da0073e9SAndroid Build Coastguard Worker ) 489*da0073e9SAndroid Build Coastguard Worker ] 490*da0073e9SAndroid Build Coastguard Worker if len(keys) == 1: 491*da0073e9SAndroid Build Coastguard Worker return generate_static_dispatch_backend_call(sig, f, keys[0]) 492*da0073e9SAndroid Build Coastguard Worker elif len(keys) == 0: 493*da0073e9SAndroid Build Coastguard Worker return generate_static_dispatch_fallback_call(sig, f, backend_indices) 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker native_tensor_args = [ 496*da0073e9SAndroid Build Coastguard Worker a.name 497*da0073e9SAndroid Build Coastguard Worker for a in sig.arguments() 498*da0073e9SAndroid Build Coastguard Worker if isinstance(a.argument, SelfArgument) 499*da0073e9SAndroid Build Coastguard Worker or isinstance(a.argument, Argument) 500*da0073e9SAndroid Build Coastguard Worker and a.argument.type.is_tensor_like() 501*da0073e9SAndroid Build Coastguard Worker ] 502*da0073e9SAndroid Build Coastguard Worker tensor_args = ", ".join(native_tensor_args) 503*da0073e9SAndroid Build 
Coastguard Worker tensor_opts = f.func.arguments.tensor_options 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker stmts = [] 506*da0073e9SAndroid Build Coastguard Worker subexprs: list[str] = [] 507*da0073e9SAndroid Build Coastguard Worker if tensor_opts is not None: 508*da0073e9SAndroid Build Coastguard Worker subexprs.append( 509*da0073e9SAndroid Build Coastguard Worker "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))" 510*da0073e9SAndroid Build Coastguard Worker ) 511*da0073e9SAndroid Build Coastguard Worker if tensor_args != "": 512*da0073e9SAndroid Build Coastguard Worker subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})") 513*da0073e9SAndroid Build Coastguard Worker stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""") 514*da0073e9SAndroid Build Coastguard Worker stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);") 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker dispatch_code = [] 517*da0073e9SAndroid Build Coastguard Worker for index in keys: 518*da0073e9SAndroid Build Coastguard Worker dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""") 519*da0073e9SAndroid Build Coastguard Worker dispatch_code.append( 520*da0073e9SAndroid Build Coastguard Worker f"""\t{generate_static_dispatch_backend_call(sig, f, index)};""" 521*da0073e9SAndroid Build Coastguard Worker ) 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices) 524*da0073e9SAndroid Build Coastguard Worker connector = "\n\t\t" 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker return f""" 527*da0073e9SAndroid Build Coastguard Worker {connector.join(stmts)} 528*da0073e9SAndroid Build Coastguard Worker switch (_dk) {{ 529*da0073e9SAndroid Build Coastguard Worker {connector.join(dispatch_code)} 
530*da0073e9SAndroid Build Coastguard Worker default: 531*da0073e9SAndroid Build Coastguard Worker {fallback} 532*da0073e9SAndroid Build Coastguard Worker }} 533*da0073e9SAndroid Build Coastguard Worker """ 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker# Generates RegisterSchema.cpp. Depending on the selector, either 537*da0073e9SAndroid Build Coastguard Worker# all schemas are registered, or only some are (in the case of 538*da0073e9SAndroid Build Coastguard Worker# selective build) 539*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 540*da0073e9SAndroid Build Coastguard Workerclass RegisterSchema: 541*da0073e9SAndroid Build Coastguard Worker selector: SelectiveBuilder 542*da0073e9SAndroid Build Coastguard Worker known_tags: dict[str, int] = field(default_factory=dict) 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker @method_with_native_function 545*da0073e9SAndroid Build Coastguard Worker def __call__(self, f: NativeFunction) -> str | None: 546*da0073e9SAndroid Build Coastguard Worker if not self.selector.is_native_function_selected(f): 547*da0073e9SAndroid Build Coastguard Worker return None 548*da0073e9SAndroid Build Coastguard Worker tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}" 549*da0073e9SAndroid Build Coastguard Worker if tags == "{}": 550*da0073e9SAndroid Build Coastguard Worker return f"m.def({cpp_string(str(f.func))}, {{}});\n" 551*da0073e9SAndroid Build Coastguard Worker maybe_tags = "" 552*da0073e9SAndroid Build Coastguard Worker if tags not in self.known_tags: 553*da0073e9SAndroid Build Coastguard Worker idx = len(self.known_tags) 554*da0073e9SAndroid Build Coastguard Worker self.known_tags[tags] = idx 555*da0073e9SAndroid Build Coastguard Worker maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n" 556*da0073e9SAndroid Build Coastguard Worker return 
f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n" 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker# Generates Operators.h and Operators.cpp. 560*da0073e9SAndroid Build Coastguard Worker# These provide macros that, given an operator and overload name, allow users 561*da0073e9SAndroid Build Coastguard Worker# to access an "un-overloaded" function version of the operator. This 562*da0073e9SAndroid Build Coastguard Worker# is useful for extension writers who want to (1) want to decltype the operator 563*da0073e9SAndroid Build Coastguard Worker# and (2) don't want to worry about method-only operators. 564*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 565*da0073e9SAndroid Build Coastguard Workerclass ComputeOperators: 566*da0073e9SAndroid Build Coastguard Worker target: Literal[Target.DECLARATION, Target.DEFINITION] 567*da0073e9SAndroid Build Coastguard Worker static_dispatch_backend_indices: list[BackendIndex] 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker @method_with_native_function 570*da0073e9SAndroid Build Coastguard Worker def __call__(self, f: NativeFunction) -> str: 571*da0073e9SAndroid Build Coastguard Worker sig = DispatcherSignature.from_schema(f.func) 572*da0073e9SAndroid Build Coastguard Worker name = f.func.name.unambiguous_name() 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker if self.target is Target.DECLARATION: 575*da0073e9SAndroid Build Coastguard Worker # Note [The ATen Operators API] 576*da0073e9SAndroid Build Coastguard Worker # The ATen Operators API lives in the at::_ops namespace, and contains compile-time 577*da0073e9SAndroid Build Coastguard Worker # metadata about each operator + entry points into the Dispatcher. 
578*da0073e9SAndroid Build Coastguard Worker # The C++ function, method, and redispatch API's are all implemented as wrappers 579*da0073e9SAndroid Build Coastguard Worker # into various bits of the structs defined here. 580*da0073e9SAndroid Build Coastguard Worker # 581*da0073e9SAndroid Build Coastguard Worker # Important characteristics about the Operators API: 582*da0073e9SAndroid Build Coastguard Worker # (1) It follows the Dispatcher API. 583*da0073e9SAndroid Build Coastguard Worker # This is kind of necessary to avoid overhead. 584*da0073e9SAndroid Build Coastguard Worker # For example: if it followed the C++ API, then all of the faithful C++ factory functions 585*da0073e9SAndroid Build Coastguard Worker # would need to wrap their arguments into TensorOptions only to unwrap them again. 586*da0073e9SAndroid Build Coastguard Worker # (2) Overload names are disambiguated. 587*da0073e9SAndroid Build Coastguard Worker # This is helpful for pytorch extenders who would like to decltype() an aten operator, 588*da0073e9SAndroid Build Coastguard Worker # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call) 589*da0073e9SAndroid Build Coastguard Worker # (3) No argument defaulting is allowed. 590*da0073e9SAndroid Build Coastguard Worker # This is more of an implementation detail to avoid #include cycles, 591*da0073e9SAndroid Build Coastguard Worker # since TensorBody.h (which defines the Tensor class) needs to include this file. 592*da0073e9SAndroid Build Coastguard Worker # (4) manual_cpp_bindings and faithful names are not included in the API. 593*da0073e9SAndroid Build Coastguard Worker # This applies to stuff like __dispatch__is_complex(), and add_outf(). 594*da0073e9SAndroid Build Coastguard Worker # These aren't "real aten ops", they're just additional functions provided by the C++ API. 
595*da0073e9SAndroid Build Coastguard Worker # They're implemented as wrappers in Functions.h that call into the actual operators 596*da0073e9SAndroid Build Coastguard Worker # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call(). 597*da0073e9SAndroid Build Coastguard Worker # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher. 598*da0073e9SAndroid Build Coastguard Worker return f""" 599*da0073e9SAndroid Build Coastguard Workerstruct TORCH_API {name} {{ 600*da0073e9SAndroid Build Coastguard Worker using schema = {sig.type()}; 601*da0073e9SAndroid Build Coastguard Worker using ptr_schema = schema*; 602*da0073e9SAndroid Build Coastguard Worker // See Note [static constexpr char* members for windows NVCC] 603*da0073e9SAndroid Build Coastguard Worker STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}") 604*da0073e9SAndroid Build Coastguard Worker STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}") 605*da0073e9SAndroid Build Coastguard Worker STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))}) 606*da0073e9SAndroid Build Coastguard Worker static {sig.defn(name="call", is_redispatching_fn=False)}; 607*da0073e9SAndroid Build Coastguard Worker static {sig.defn(name="redispatch", is_redispatching_fn=True)}; 608*da0073e9SAndroid Build Coastguard Worker}};""" 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker elif self.target is Target.DEFINITION: 611*da0073e9SAndroid Build Coastguard Worker defns = f""" 612*da0073e9SAndroid Build Coastguard WorkerSTATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}") 613*da0073e9SAndroid Build Coastguard WorkerSTATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}") 614*da0073e9SAndroid Build Coastguard WorkerSTATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, 
{cpp_string(str(f.func))}) 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker// aten::{f.func} 617*da0073e9SAndroid Build Coastguard Workerstatic C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{ 618*da0073e9SAndroid Build Coastguard Worker return c10::Dispatcher::singleton() 619*da0073e9SAndroid Build Coastguard Worker .findSchemaOrThrow({name}::name, {name}::overload_name) 620*da0073e9SAndroid Build Coastguard Worker .typed<{name}::schema>(); 621*da0073e9SAndroid Build Coastguard Worker}} 622*da0073e9SAndroid Build Coastguard Worker""" 623*da0073e9SAndroid Build Coastguard Worker for is_redispatching_fn in [False, True]: 624*da0073e9SAndroid Build Coastguard Worker if is_redispatching_fn: 625*da0073e9SAndroid Build Coastguard Worker dispatcher_exprs_str = ", ".join( 626*da0073e9SAndroid Build Coastguard Worker ["dispatchKeySet"] + [a.name for a in sig.arguments()] 627*da0073e9SAndroid Build Coastguard Worker ) 628*da0073e9SAndroid Build Coastguard Worker method_base = "redispatch" 629*da0073e9SAndroid Build Coastguard Worker else: 630*da0073e9SAndroid Build Coastguard Worker dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()]) 631*da0073e9SAndroid Build Coastguard Worker method_base = "call" 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker dispatcher_call = method_base 634*da0073e9SAndroid Build Coastguard Worker method_name = f"{name}::{method_base}" 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker fn_body = f""" 637*da0073e9SAndroid Build Coastguard Worker static auto op = create_{name}_typed_handle(); 638*da0073e9SAndroid Build Coastguard Worker return op.{dispatcher_call}({dispatcher_exprs_str});""" 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker if ( 641*da0073e9SAndroid Build Coastguard Worker not is_redispatching_fn 642*da0073e9SAndroid Build 
Coastguard Worker and len(self.static_dispatch_backend_indices) > 0 643*da0073e9SAndroid Build Coastguard Worker ): 644*da0073e9SAndroid Build Coastguard Worker # call() should go through static dispatch 645*da0073e9SAndroid Build Coastguard Worker fn_body = static_dispatch( 646*da0073e9SAndroid Build Coastguard Worker sig, f, backend_indices=self.static_dispatch_backend_indices 647*da0073e9SAndroid Build Coastguard Worker ) 648*da0073e9SAndroid Build Coastguard Worker defns += f""" 649*da0073e9SAndroid Build Coastguard Worker// aten::{f.func} 650*da0073e9SAndroid Build Coastguard Worker{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{ 651*da0073e9SAndroid Build Coastguard Worker {fn_body} 652*da0073e9SAndroid Build Coastguard Worker}} 653*da0073e9SAndroid Build Coastguard Worker""" 654*da0073e9SAndroid Build Coastguard Worker return defns 655*da0073e9SAndroid Build Coastguard Worker else: 656*da0073e9SAndroid Build Coastguard Worker assert_never(self.target) 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker# Generates Functions.h, which provides the functional public C++ API, 660*da0073e9SAndroid Build Coastguard Worker# and the scaffolding to call into the dispatcher from these functions. 
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the inline C++ wrappers for `f` that forward to `at::_ops::<name>::call`.

        For every C++ signature of `f`, a plain `inline` function is emitted
        when `f` has a function variant, and a `symint::` template overload is
        emitted when the schema involves SymInt.
        """
        signature_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        emit_symint_template = f.func.has_symint()
        emit_function = Variant.function in f.variants
        # See Note [The ATen Operators API]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_call = f"at::_ops::{f.func.name.unambiguous_name()}::call"

        pieces: list[str] = []
        for sig in signature_group.signatures():
            call_args = ", ".join(
                [e.expr for e in translate(sig.arguments(), dispatcher_sig.arguments())]
            )
            intlike_t = "c10::SymInt" if sig.symint else "int64_t"

            if emit_function:
                pieces.append(
                    f"""
// aten::{f.func}
inline {sig.decl()} {{
    return {op_call}({call_args});
}}"""
                )

            # The template function can be used from template situations
            # where you want to switch between the symint or not version
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if emit_symint_template:
                pieces.append(
                    f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return {op_call}({call_args});
  }}
}}
"""
                )
        return "".join(pieces)


# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    # Whether to emit method declarations or inline definitions.
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the `Tensor` method declarations/definitions for `f`.

        Returns None when `f` has no method variant.
        """
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        signature_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            return "".join(
                [f"{sig.decl()} const;\n" for sig in signature_group.signatures()]
            )

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        pieces: list[str] = []
        for sig in signature_group.signatures():
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(
                sig.arguments(), dispatcher_sig.arguments(), method=True
            )
            call_args = ", ".join([e.expr for e in translated])

            pieces.append(
                f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({call_args});
}}
"""
            )

        return "".join(pieces)


# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the inline redispatch wrappers for every C++ signature of `f`."""
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods,
        signature_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        pieces: list[str] = []
        for sig in signature_group.signatures():
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(sig.arguments(), dispatcher_sig.arguments())
            # The redispatch entry point takes the key set as its first argument.
            call_args = ", ".join(["dispatchKeySet", *(a.expr for a in translated)])

            pieces.append(
                f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({call_args});
}}
"""
            )

        return "".join(pieces)


# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
783*da0073e9SAndroid Build Coastguard Worker# TODO: This was historically used to help some JIT interop code 784*da0073e9SAndroid Build Coastguard Worker# figure out whether or not to treat aten namespace'd operators 785*da0073e9SAndroid Build Coastguard Worker# one way or another, we should reevaluate if this is actually needed. 786*da0073e9SAndroid Build Coastguard Worker@with_native_function 787*da0073e9SAndroid Build Coastguard Workerdef compute_aten_op(f: NativeFunction) -> str: 788*da0073e9SAndroid Build Coastguard Worker return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},' 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker 791*da0073e9SAndroid Build Coastguard Worker# Generates MetaFunctions.h 792*da0073e9SAndroid Build Coastguard Workerdef compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None: 793*da0073e9SAndroid Build Coastguard Worker if not g.structured: 794*da0073e9SAndroid Build Coastguard Worker return None 795*da0073e9SAndroid Build Coastguard Worker with native_function_manager(g.out): 796*da0073e9SAndroid Build Coastguard Worker name = meta.name(g) 797*da0073e9SAndroid Build Coastguard Worker args = structured.meta_arguments(g) 798*da0073e9SAndroid Build Coastguard Worker args_str = ", ".join(a.decl() for a in args) 799*da0073e9SAndroid Build Coastguard Worker parent_class = g.out.structured_inherits 800*da0073e9SAndroid Build Coastguard Worker if parent_class is None: 801*da0073e9SAndroid Build Coastguard Worker parent_class = "at::impl::MetaBase" 802*da0073e9SAndroid Build Coastguard Worker meta_return = "void" 803*da0073e9SAndroid Build Coastguard Worker precomputed = g.out.precomputed if g.structured else None 804*da0073e9SAndroid Build Coastguard Worker 805*da0073e9SAndroid Build Coastguard Worker if precomputed: 806*da0073e9SAndroid Build Coastguard Worker # Generate the template declaration with one bool parameter for each 807*da0073e9SAndroid Build Coastguard 
Worker # precomputed element. Each parameter is true if the corresponding (in 808*da0073e9SAndroid Build Coastguard Worker # terms of position) precomputed element has been set. 809*da0073e9SAndroid Build Coastguard Worker precomputed_values = [*precomputed.replace.values(), precomputed.add] 810*da0073e9SAndroid Build Coastguard Worker precomputed_elements = [ 811*da0073e9SAndroid Build Coastguard Worker elem for replace_list in precomputed_values for elem in replace_list 812*da0073e9SAndroid Build Coastguard Worker ] 813*da0073e9SAndroid Build Coastguard Worker precomputed_template_parameters = [ 814*da0073e9SAndroid Build Coastguard Worker elem.name.upper() for elem in precomputed_elements 815*da0073e9SAndroid Build Coastguard Worker ] 816*da0073e9SAndroid Build Coastguard Worker precomputed_template_params_str = ", ".join( 817*da0073e9SAndroid Build Coastguard Worker f"bool {param} = false" for param in precomputed_template_parameters 818*da0073e9SAndroid Build Coastguard Worker ) 819*da0073e9SAndroid Build Coastguard Worker precompute_template_decl = f"template <{precomputed_template_params_str}>" 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker # Generate a string containing declarations of all precomputed elements. 
822*da0073e9SAndroid Build Coastguard Worker precomputed_elements_with_cpp_types = [ 823*da0073e9SAndroid Build Coastguard Worker structured.argument_type(elem, binds=elem.name) 824*da0073e9SAndroid Build Coastguard Worker for elem in precomputed_elements 825*da0073e9SAndroid Build Coastguard Worker ] 826*da0073e9SAndroid Build Coastguard Worker 827*da0073e9SAndroid Build Coastguard Worker precomputed_elements_decl = ";\n".join( 828*da0073e9SAndroid Build Coastguard Worker f"{elem.cpp_type(strip_ref=True)} {elem.name}" 829*da0073e9SAndroid Build Coastguard Worker for elem in precomputed_elements_with_cpp_types 830*da0073e9SAndroid Build Coastguard Worker ) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker # Generate "setter" methods for each precomputed element. Each method will return 833*da0073e9SAndroid Build Coastguard Worker # a new instance of precompute_out with the template parameter that corresponds to 834*da0073e9SAndroid Build Coastguard Worker # the member set by the method to true (to indicate that it has been set). 835*da0073e9SAndroid Build Coastguard Worker setter_methods = [] 836*da0073e9SAndroid Build Coastguard Worker for i, elem in enumerate(precomputed_elements): 837*da0073e9SAndroid Build Coastguard Worker # Generate the signature. The return type will be the same 838*da0073e9SAndroid Build Coastguard Worker # as the type of `this` but with the template parameter 839*da0073e9SAndroid Build Coastguard Worker # corresponding to the element set by this method set to true. 840*da0073e9SAndroid Build Coastguard Worker # The assert generated below will ensure that this template 841*da0073e9SAndroid Build Coastguard Worker # parameter is false on the type of `this`. 
842*da0073e9SAndroid Build Coastguard Worker return_ty_templates = ", ".join( 843*da0073e9SAndroid Build Coastguard Worker precomputed_template_parameters[:i] 844*da0073e9SAndroid Build Coastguard Worker + ["true"] 845*da0073e9SAndroid Build Coastguard Worker + precomputed_template_parameters[i + 1 :] 846*da0073e9SAndroid Build Coastguard Worker ) 847*da0073e9SAndroid Build Coastguard Worker return_ty = f"precompute_out<{return_ty_templates}>" 848*da0073e9SAndroid Build Coastguard Worker elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type( 849*da0073e9SAndroid Build Coastguard Worker strip_ref=True 850*da0073e9SAndroid Build Coastguard Worker ) 851*da0073e9SAndroid Build Coastguard Worker signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)" 852*da0073e9SAndroid Build Coastguard Worker 853*da0073e9SAndroid Build Coastguard Worker # Generate an assert which checks that the 854*da0073e9SAndroid Build Coastguard Worker # template parameter corresponding to the precomputed 855*da0073e9SAndroid Build Coastguard Worker # element that is set by this method is false on the 856*da0073e9SAndroid Build Coastguard Worker # class corresponding to the object that `this` points to. 857*da0073e9SAndroid Build Coastguard Worker # This ensures that each element can be set only once. 858*da0073e9SAndroid Build Coastguard Worker assert_msg = f'"{elem.name} already set"' 859*da0073e9SAndroid Build Coastguard Worker assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});" 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Worker # Generate the new object construction block. All state 862*da0073e9SAndroid Build Coastguard Worker # except the element that this method sets is copied from the 863*da0073e9SAndroid Build Coastguard Worker # object that `this` points to. The value for the element that 864*da0073e9SAndroid Build Coastguard Worker # the method sets is taken from a method parameter. 
865*da0073e9SAndroid Build Coastguard Worker construction_stmts = [] 866*da0073e9SAndroid Build Coastguard Worker construction_stmts.append(f"{return_ty} ret;") 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker for j, elem in enumerate(precomputed_elements): 869*da0073e9SAndroid Build Coastguard Worker if i == j: 870*da0073e9SAndroid Build Coastguard Worker construction_stmts.append(f"ret.{elem.name} = value;") 871*da0073e9SAndroid Build Coastguard Worker else: 872*da0073e9SAndroid Build Coastguard Worker construction_stmts.append( 873*da0073e9SAndroid Build Coastguard Worker f"ret.{elem.name} = this->{elem.name};" 874*da0073e9SAndroid Build Coastguard Worker ) 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker construction_stmts.append("return ret;") 877*da0073e9SAndroid Build Coastguard Worker construction_block = "\n".join(construction_stmts) 878*da0073e9SAndroid Build Coastguard Worker 879*da0073e9SAndroid Build Coastguard Worker setter_methods.append( 880*da0073e9SAndroid Build Coastguard Worker f""" 881*da0073e9SAndroid Build Coastguard Worker {signature} {{ 882*da0073e9SAndroid Build Coastguard Worker {assert_stmt} 883*da0073e9SAndroid Build Coastguard Worker {construction_block} 884*da0073e9SAndroid Build Coastguard Worker }} 885*da0073e9SAndroid Build Coastguard Worker """ 886*da0073e9SAndroid Build Coastguard Worker ) 887*da0073e9SAndroid Build Coastguard Worker setter_methods_decl = "\n".join(setter_methods) 888*da0073e9SAndroid Build Coastguard Worker 889*da0073e9SAndroid Build Coastguard Worker # Meta should return an instance of the struct containing the precomputed elements. 
890*da0073e9SAndroid Build Coastguard Worker meta_return_template_params = ", ".join( 891*da0073e9SAndroid Build Coastguard Worker ["true"] * len(precomputed_template_parameters) 892*da0073e9SAndroid Build Coastguard Worker ) 893*da0073e9SAndroid Build Coastguard Worker # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return 894*da0073e9SAndroid Build Coastguard Worker # type (which has a variable number of template parameters). 895*da0073e9SAndroid Build Coastguard Worker meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;" 896*da0073e9SAndroid Build Coastguard Worker meta_return = "meta_return_ty" 897*da0073e9SAndroid Build Coastguard Worker precomputed_decl = f""" 898*da0073e9SAndroid Build Coastguard Worker {precompute_template_decl} 899*da0073e9SAndroid Build Coastguard Worker struct TORCH_API precompute_out {{ 900*da0073e9SAndroid Build Coastguard Worker {setter_methods_decl} 901*da0073e9SAndroid Build Coastguard Worker {precomputed_elements_decl}; 902*da0073e9SAndroid Build Coastguard Worker }};""" 903*da0073e9SAndroid Build Coastguard Worker else: 904*da0073e9SAndroid Build Coastguard Worker meta_return_typedef = "" 905*da0073e9SAndroid Build Coastguard Worker precomputed_decl = "" 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker return f"""\ 908*da0073e9SAndroid Build Coastguard Workerstruct TORCH_API structured_{name} : public {parent_class} {{ 909*da0073e9SAndroid Build Coastguard Worker {precomputed_decl} 910*da0073e9SAndroid Build Coastguard Worker {meta_return_typedef} 911*da0073e9SAndroid Build Coastguard Worker {meta_return} meta({args_str}); 912*da0073e9SAndroid Build Coastguard Worker}}; 913*da0073e9SAndroid Build Coastguard Worker""" 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Workerdef needs_backend_select(f: NativeFunction, selector: 
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    """Return True when a BackendSelect kernel should be emitted for ``f``.

    Operators whose base name ends in ``_like`` or starts with ``new_`` are
    rejected, as are operators that have no TensorOptions argument group.
    For everything else the decision is delegated to
    ``selector.is_native_function_selected``.
    """
    op_name = str(f.func.name.name)
    excluded_by_name = op_name.endswith("_like") or op_name.startswith("new_")
    # `or` short-circuits, so tensor_options is only inspected for operators
    # that were not already excluded by name.
    if excluded_by_name or f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    """Callable that renders the BackendSelect code for one native function.

    Depending on ``target`` the call yields either the C++ kernel definition
    or the ``m.impl(...)`` registration line.  Operators rejected by
    ``needs_backend_select`` yield ``None``.
    """

    # Which piece of code to emit: the kernel body or its registration.
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not needs_backend_select(f, self.selector):
            return None

        kernel_name = native.name(f.func)

        # BackendSelect can go to Meta, so it must preserve symints
        tensor_bindings = []
        for binding in NativeSignature(f.func, symint=True).arguments():
            schema_arg = binding.argument
            if isinstance(schema_arg, Argument) and schema_arg.type.is_tensor_like():
                tensor_bindings.append(binding)

        sig = DispatcherSignature.from_schema(f.func)
        redispatch_exprs = sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({kernel_name}));"""
        elif self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently.
            # The first case could probably be improved though: it calls
            # computeDispatchKeySet(), which looks at TLS dispatch keys, and
            # there should not be any by the time we reach backend select.
            if tensor_bindings:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(b.name for b in tensor_bindings)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            redispatch_args = ", ".join(e.expr for e in redispatch_exprs)
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(kernel_name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {redispatch_args});
}}
"""
        else:
            assert_never(self.target)
def format_yaml(data: object) -> str:
    """Dump ``data`` to a YAML string using ``YamlDumper``.

    As a side effect the ``YamlDumper`` class is (re)configured on every
    call: alias generation is disabled and a representer for
    ``OrderedDict`` is registered.
    """

    # Support serializing OrderedDict
    def represent_ordered_dict(dumper: Any, mapping: Any) -> Any:
        return dumper.represent_dict(mapping.items())

    # Ignore alias in Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]
    YamlDumper.add_representer(OrderedDict, represent_ordered_dict)  # type: ignore[no-untyped-call]

    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    rendered = yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
    return rendered
# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings.  This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    """Convert a default-value string into a native Python object when possible.

    ``"true"`` / ``"false"`` become booleans; otherwise the string is parsed
    as an ``int``, then as a ``float``; if neither parse succeeds the string
    is returned unchanged.
    """
    if s == "true":
        return True
    if s == "false":
        return False

    # Try the numeric parsers from most to least specific.
    for parse in (int, float):
        try:
            return parse(s)
        except ValueError:
            continue
    return s
# What is a dynamic type?  Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH).  These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    """Return the legacy "dynamic type" string for schema type ``t``.

    Optional wrappers are stripped first; a plain ``Tensor`` maps to
    ``"at::Tensor"``; every other type is rendered through
    ``cpp.argumenttype_type`` with ``symint=False``.
    """
    inner = t
    # Peel off every level of optionality before classifying the type.
    while isinstance(inner, OptionalType):
        inner = inner.elem

    # Note we don't use is_tensor_like() here because it would
    # also include Tensor[]
    if str(inner) == "Tensor":
        return "at::Tensor"

    # This is a legacy concept, so never report SymInt
    legacy_ctype = cpp.argumenttype_type(
        inner, mutable=False, binds="__placeholder__", symint=False
    )
    return legacy_ctype.cpp_type()
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    """Return the ``method_of`` list for the given set of variants.

    The list always starts with ``"Type"``, followed by ``"Tensor"`` when the
    method variant is present and ``"namespace"`` when the function variant
    is present, in that fixed order.
    """
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    ordered_labels = (
        (Variant.method, "Tensor"),
        (Variant.function, "namespace"),
    )
    method_of = ["Type"]
    for variant, label in ordered_labels:
        if variant in variants:
            method_of.append(label)
    return method_of
def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    """Compute the ``returns`` YAML entries for ``f``.

    Returns a pair ``(returns, name_to_field_name)``: the list of per-return
    dictionaries, and a mapping from out-argument name to the name of the
    corresponding return (see the note below).
    """
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    cpp_names = cpp.return_names(f)
    returns = []
    for position, (ret, cpp_name) in enumerate(zip(f.func.returns, cpp_names)):
        entry = {
            "dynamic_type": dynamic_type(ret.type),
            "name": cpp_name,
            # legacy, report ints
            "type": cpp.return_type(ret, symint=False).cpp_type(),
        }

        field_name = ret.name
        if field_name:
            # See Note [name and field_name]
            entry["field_name"] = field_name
            if f.func.is_out_fn():
                out_arg = f.func.arguments.out[position]
                name_to_field_name[out_arg.name] = field_name

        returns.append(entry)

    return returns, name_to_field_name
# arguments in yaml roughly corresponds to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Render one C++ signature binding as a YAML argument dictionary.

    A ``TensorOptionsArguments`` binding is rendered directly as a single
    kwarg-only ``at::TensorOptions`` entry (with ``default`` included only
    when the binding has one).  A plain ``Argument`` binding is forwarded to
    ``compute_argument_yaml`` together with the keyword arguments received
    here.

    Raises:
        AssertionError: if the binding wraps a ``SelfArgument``, or any other
            argument kind this function does not know how to render.
    """
    argument = cpp_a.argument
    if isinstance(argument, TensorOptionsArguments):
        arg: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            arg["default"] = cpp_a.default
        return arg
    elif isinstance(argument, SelfArgument):
        # Same exception type as before, now with a message that identifies
        # the offending binding.
        raise AssertionError(
            f"unexpected SelfArgument binding {cpp_a.name!r}: "
            "only plain and TensorOptions arguments can be rendered"
        )
    elif isinstance(argument, Argument):
        return compute_argument_yaml(
            argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
    else:
        # Previously an unknown binding kind fell off the end and silently
        # returned None; fail loudly instead so it cannot leak into the YAML.
        assert_never(argument)
def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Render one schema argument as a YAML argument dictionary.

    ``kwarg_only_set`` and ``out_arg_set`` hold argument names and decide
    whether the ``kwarg_only`` and ``output``/``allocate`` flags are added.
    ``name_to_field_name`` supplies the ``field_name`` of an out argument.
    ``schema_order`` is accepted but not read by this function.
    """
    annotation = str(a.annotation) if a.annotation else None
    arg: dict[str, object] = {
        "annotation": annotation,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        # legacy, report ints
        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
    }

    if a.default is not None:
        default_expr = cpp.default_expr(a.default, a.type, symint=False)
        arg["default"] = pythonify_default(default_expr)

    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True

    if a.name in out_arg_set:
        arg["output"] = True
        arg["allocate"] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg["field_name"] = name_to_field_name[a.name]

    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    list_type = a.type.is_list_like()
    has_recorded_size = (
        list_type is not None
        and list_type.size is not None
        and str(list_type.elem) != "bool"
    )
    if has_recorded_size:
        arg["size"] = list_type.size
    return arg
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Assemble the ordered YAML declaration entry for native function ``f``.

    The entry combines the rendered returns, the arguments in C++ signature
    order and in schema order, a schema-order C++ signature string, and
    assorted flags read from ``f``.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {kwarg.name for kwarg in f.func.arguments.flat_kwarg_only}
    out_arg_set = {out.name for out in f.func.arguments.out}

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    cpp_args = sig_group.signature.arguments()

    # Arguments as they appear in the C++ signature.
    arguments = []
    for binding in cpp_args:
        arguments.append(
            compute_cpp_argument_yaml(
                binding,
                schema_order=False,
                kwarg_only_set=kwarg_only_set,
                out_arg_set=out_arg_set,
                name_to_field_name=name_to_field_name,
            )
        )

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    # The same arguments, in the order the schema declares them.
    schema_order_arguments = []
    for schema_arg in schema_order_jit_arguments:
        schema_order_arguments.append(
            compute_argument_yaml(
                schema_arg,
                schema_order=True,
                kwarg_only_set=kwarg_only_set,
                out_arg_set=out_arg_set,
                name_to_field_name=name_to_field_name,
            )
        )

    cpp_schema_order_types = []
    for schema_arg in schema_order_jit_arguments:
        # NB: method here doesn't matter
        bindings = cpp.argument(
            schema_arg,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
        cpp_schema_order_types.extend(b.type for b in bindings)

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    takes_tensor_options = any(
        isinstance(binding.argument, TensorOptionsArguments) for binding in cpp_args
    )
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    # Keyword arguments keep their order, so the entry is laid out exactly
    # in the sequence written here.
    return OrderedDict(
        name=cpp.name(f.func),
        operator_name=str(f.func.name.name),
        overload_name=str(f.func.name.overload_name),
        manual_kernel_registration=f.manual_kernel_registration,
        category_override=(
            f.category_override if f.category_override is not None else ""
        ),
        schema_string=f"aten::{f.func}",
        arguments=arguments,
        schema_order_cpp_signature=schema_order_cpp_signature,
        schema_order_arguments=schema_order_arguments,
        method_of=compute_method_of_yaml(f.variants),
        mode="native",
        python_module="" if f.python_module is None else f.python_module,
        returns=returns,
        inplace=f.func.name.name.inplace,
        is_factory_method=is_factory_method,
        abstract=f.is_abstract,
        device_guard=f.device_guard,
        with_gil=False,
        deprecated=False,
        has_math_kernel=f.has_composite_implicit_autograd_kernel,
    )
1271*da0073e9SAndroid Build Coastguard Worker ("schema_order_cpp_signature", schema_order_cpp_signature), 1272*da0073e9SAndroid Build Coastguard Worker ("schema_order_arguments", schema_order_arguments), 1273*da0073e9SAndroid Build Coastguard Worker ("method_of", compute_method_of_yaml(f.variants)), 1274*da0073e9SAndroid Build Coastguard Worker ("mode", "native"), 1275*da0073e9SAndroid Build Coastguard Worker ("python_module", "" if f.python_module is None else f.python_module), 1276*da0073e9SAndroid Build Coastguard Worker ("returns", returns), 1277*da0073e9SAndroid Build Coastguard Worker ("inplace", f.func.name.name.inplace), 1278*da0073e9SAndroid Build Coastguard Worker ("is_factory_method", is_factory_method), 1279*da0073e9SAndroid Build Coastguard Worker ("abstract", f.is_abstract), 1280*da0073e9SAndroid Build Coastguard Worker ("device_guard", f.device_guard), 1281*da0073e9SAndroid Build Coastguard Worker ("with_gil", False), 1282*da0073e9SAndroid Build Coastguard Worker ("deprecated", False), 1283*da0073e9SAndroid Build Coastguard Worker ("has_math_kernel", f.has_composite_implicit_autograd_kernel), 1284*da0073e9SAndroid Build Coastguard Worker ] 1285*da0073e9SAndroid Build Coastguard Worker ) 1286*da0073e9SAndroid Build Coastguard Worker 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker# See Note [Auto generated composite kernels] 1289*da0073e9SAndroid Build Coastguard Workerdef has_autogenerated_composite_kernel(f: NativeFunction) -> bool: 1290*da0073e9SAndroid Build Coastguard Worker return (f.structured or f.structured_delegate is not None) and ( 1291*da0073e9SAndroid Build Coastguard Worker f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace 1292*da0073e9SAndroid Build Coastguard Worker ) 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker 1295*da0073e9SAndroid Build Coastguard Worker@with_native_function_and_indices 1296*da0073e9SAndroid Build 
Coastguard Workerdef compute_registration_declarations( 1297*da0073e9SAndroid Build Coastguard Worker f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex] 1298*da0073e9SAndroid Build Coastguard Worker) -> str: 1299*da0073e9SAndroid Build Coastguard Worker name = dispatcher.name(f.func) 1300*da0073e9SAndroid Build Coastguard Worker returns_type = dispatcher.returns_type( 1301*da0073e9SAndroid Build Coastguard Worker f.func.returns 1302*da0073e9SAndroid Build Coastguard Worker ).cpp_type_registration_declarations() 1303*da0073e9SAndroid Build Coastguard Worker args = dispatcher.arguments(f.func) 1304*da0073e9SAndroid Build Coastguard Worker args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args) 1305*da0073e9SAndroid Build Coastguard Worker comment_data: dict[str, str] = { 1306*da0073e9SAndroid Build Coastguard Worker "schema": f"aten::{f.func}", 1307*da0073e9SAndroid Build Coastguard Worker # TODO: What exactly is the semantics of the 'dispatch' field? 
1308*da0073e9SAndroid Build Coastguard Worker "dispatch": str( 1309*da0073e9SAndroid Build Coastguard Worker {k for k, v in backend_indices.items() if v.has_kernel(f)} 1310*da0073e9SAndroid Build Coastguard Worker != {DispatchKey.CompositeImplicitAutograd} 1311*da0073e9SAndroid Build Coastguard Worker and {k for k, v in backend_indices.items() if v.has_kernel(f)} 1312*da0073e9SAndroid Build Coastguard Worker != { 1313*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutograd, 1314*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor, 1315*da0073e9SAndroid Build Coastguard Worker } 1316*da0073e9SAndroid Build Coastguard Worker ), 1317*da0073e9SAndroid Build Coastguard Worker "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)), 1318*da0073e9SAndroid Build Coastguard Worker } 1319*da0073e9SAndroid Build Coastguard Worker return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} 1320*da0073e9SAndroid Build Coastguard Worker""" 1321*da0073e9SAndroid Build Coastguard Worker 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1324*da0073e9SAndroid Build Coastguard Worker# 1325*da0073e9SAndroid Build Coastguard Worker# RUN IT ALL 1326*da0073e9SAndroid Build Coastguard Worker# 1327*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker 1330*da0073e9SAndroid Build Coastguard Workerdef get_custom_build_selector( 1331*da0073e9SAndroid Build Coastguard Worker provided_op_registration_allowlist: list[str] | None, 1332*da0073e9SAndroid Build Coastguard Worker op_selection_yaml_path: str | None, 1333*da0073e9SAndroid Build Coastguard Worker) -> SelectiveBuilder: 1334*da0073e9SAndroid Build Coastguard Worker assert not ( 
1335*da0073e9SAndroid Build Coastguard Worker provided_op_registration_allowlist is not None 1336*da0073e9SAndroid Build Coastguard Worker and op_selection_yaml_path is not None 1337*da0073e9SAndroid Build Coastguard Worker ), ( 1338*da0073e9SAndroid Build Coastguard Worker "Both provided_op_registration_allowlist and " 1339*da0073e9SAndroid Build Coastguard Worker + "op_selection_yaml_path can NOT be provided at the " 1340*da0073e9SAndroid Build Coastguard Worker + "same time." 1341*da0073e9SAndroid Build Coastguard Worker ) 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Worker op_registration_allowlist: set[str] | None = None 1344*da0073e9SAndroid Build Coastguard Worker if provided_op_registration_allowlist is not None: 1345*da0073e9SAndroid Build Coastguard Worker op_registration_allowlist = set(provided_op_registration_allowlist) 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker if op_registration_allowlist is not None: 1348*da0073e9SAndroid Build Coastguard Worker selector = SelectiveBuilder.from_legacy_op_registration_allow_list( 1349*da0073e9SAndroid Build Coastguard Worker op_registration_allowlist, 1350*da0073e9SAndroid Build Coastguard Worker True, 1351*da0073e9SAndroid Build Coastguard Worker False, 1352*da0073e9SAndroid Build Coastguard Worker ) 1353*da0073e9SAndroid Build Coastguard Worker elif op_selection_yaml_path is not None: 1354*da0073e9SAndroid Build Coastguard Worker selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path) 1355*da0073e9SAndroid Build Coastguard Worker else: 1356*da0073e9SAndroid Build Coastguard Worker selector = SelectiveBuilder.get_nop_selector() 1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker return selector 1359*da0073e9SAndroid Build Coastguard Worker 1360*da0073e9SAndroid Build Coastguard Worker 1361*da0073e9SAndroid Build Coastguard Workerdef get_grouped_by_view_native_functions( 
1362*da0073e9SAndroid Build Coastguard Worker native_functions: Sequence[NativeFunction], 1363*da0073e9SAndroid Build Coastguard Worker) -> Sequence[NativeFunction | NativeFunctionsViewGroup]: 1364*da0073e9SAndroid Build Coastguard Worker def maybe_create_view_group( 1365*da0073e9SAndroid Build Coastguard Worker d: dict[ViewSchemaKind | SchemaKind, NativeFunction] 1366*da0073e9SAndroid Build Coastguard Worker ) -> list[NativeFunction | NativeFunctionsViewGroup]: 1367*da0073e9SAndroid Build Coastguard Worker funcs: list[NativeFunction | NativeFunctionsViewGroup] = [] 1368*da0073e9SAndroid Build Coastguard Worker if ViewSchemaKind.aliasing in d: 1369*da0073e9SAndroid Build Coastguard Worker view = d.pop(ViewSchemaKind.aliasing) 1370*da0073e9SAndroid Build Coastguard Worker view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None) 1371*da0073e9SAndroid Build Coastguard Worker view_copy = d.pop(SchemaKind.functional, None) 1372*da0073e9SAndroid Build Coastguard Worker 1373*da0073e9SAndroid Build Coastguard Worker funcs.append( 1374*da0073e9SAndroid Build Coastguard Worker NativeFunctionsViewGroup( 1375*da0073e9SAndroid Build Coastguard Worker view=view, 1376*da0073e9SAndroid Build Coastguard Worker view_copy=view_copy, 1377*da0073e9SAndroid Build Coastguard Worker view_inplace=view_inplace, 1378*da0073e9SAndroid Build Coastguard Worker ) 1379*da0073e9SAndroid Build Coastguard Worker ) 1380*da0073e9SAndroid Build Coastguard Worker # Take the remaining functions that weren't part of the view group 1381*da0073e9SAndroid Build Coastguard Worker # and emit them separately 1382*da0073e9SAndroid Build Coastguard Worker funcs.extend(d.values()) 1383*da0073e9SAndroid Build Coastguard Worker return funcs 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker grouped_by_views: dict[ 1386*da0073e9SAndroid Build Coastguard Worker FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction] 1387*da0073e9SAndroid Build Coastguard Worker ] = 
defaultdict(dict) 1388*da0073e9SAndroid Build Coastguard Worker for f in native_functions: 1389*da0073e9SAndroid Build Coastguard Worker schema = f.func.view_signature() 1390*da0073e9SAndroid Build Coastguard Worker view_kind: ViewSchemaKind = f.view_schema_kind 1391*da0073e9SAndroid Build Coastguard Worker # We need to group up ops relevant to the same "view", consisting of: 1392*da0073e9SAndroid Build Coastguard Worker # view op (ViewSchemaKind.aliasing) 1393*da0073e9SAndroid Build Coastguard Worker # view_inplace op (ViewSchemaKind.aliasing_inplace) 1394*da0073e9SAndroid Build Coastguard Worker # view_copy op (SchemaKind.functional) 1395*da0073e9SAndroid Build Coastguard Worker if view_kind == ViewSchemaKind.non_aliasing: 1396*da0073e9SAndroid Build Coastguard Worker kind = f.func.kind() 1397*da0073e9SAndroid Build Coastguard Worker assert kind not in grouped_by_views[schema] 1398*da0073e9SAndroid Build Coastguard Worker grouped_by_views[schema][kind] = f 1399*da0073e9SAndroid Build Coastguard Worker else: 1400*da0073e9SAndroid Build Coastguard Worker assert ( 1401*da0073e9SAndroid Build Coastguard Worker view_kind not in grouped_by_views[schema] 1402*da0073e9SAndroid Build Coastguard Worker ), f"{view_kind} already in {grouped_by_views[schema].keys()}" 1403*da0073e9SAndroid Build Coastguard Worker grouped_by_views[schema][view_kind] = f 1404*da0073e9SAndroid Build Coastguard Worker 1405*da0073e9SAndroid Build Coastguard Worker return list(concatMap(maybe_create_view_group, grouped_by_views.values())) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Workerdef get_grouped_native_functions( 1409*da0073e9SAndroid Build Coastguard Worker native_functions: Sequence[NativeFunction], 1410*da0073e9SAndroid Build Coastguard Worker) -> Sequence[NativeFunction | NativeFunctionsGroup]: 1411*da0073e9SAndroid Build Coastguard Worker def flatten_pre_group( 1412*da0073e9SAndroid Build Coastguard 
Worker d: dict[SchemaKind, NativeFunction] 1413*da0073e9SAndroid Build Coastguard Worker ) -> Sequence[NativeFunction | NativeFunctionsGroup]: 1414*da0073e9SAndroid Build Coastguard Worker r = NativeFunctionsGroup.from_dict(d) 1415*da0073e9SAndroid Build Coastguard Worker if r is None: 1416*da0073e9SAndroid Build Coastguard Worker # Invariant: any NativeFunctions that are code-generated 1417*da0073e9SAndroid Build Coastguard Worker # should have been grouped into NativeFunctionsGroup objects 1418*da0073e9SAndroid Build Coastguard Worker assert not any("generated" in f.tags for f in d.values()) 1419*da0073e9SAndroid Build Coastguard Worker return list(d.values()) 1420*da0073e9SAndroid Build Coastguard Worker else: 1421*da0073e9SAndroid Build Coastguard Worker return [r] 1422*da0073e9SAndroid Build Coastguard Worker 1423*da0073e9SAndroid Build Coastguard Worker # TODO: how come ValuesView isn't a Sequence lol 1424*da0073e9SAndroid Build Coastguard Worker pre_grouped_native_functions = pre_group_native_functions(native_functions) 1425*da0073e9SAndroid Build Coastguard Worker return list( 1426*da0073e9SAndroid Build Coastguard Worker concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())) 1427*da0073e9SAndroid Build Coastguard Worker ) 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker 1430*da0073e9SAndroid Build Coastguard Workerdef get_ns_grouped_kernels( 1431*da0073e9SAndroid Build Coastguard Worker *, 1432*da0073e9SAndroid Build Coastguard Worker grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 1433*da0073e9SAndroid Build Coastguard Worker backend_indices: dict[DispatchKey, BackendIndex], 1434*da0073e9SAndroid Build Coastguard Worker native_function_decl_gen: Callable[ 1435*da0073e9SAndroid Build Coastguard Worker [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str] 1436*da0073e9SAndroid Build Coastguard Worker ] = dest.compute_native_function_declaration, 
1437*da0073e9SAndroid Build Coastguard Worker) -> dict[str, list[str]]: 1438*da0073e9SAndroid Build Coastguard Worker ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) 1439*da0073e9SAndroid Build Coastguard Worker for f in grouped_native_functions: 1440*da0073e9SAndroid Build Coastguard Worker native_function_namespaces = set() 1441*da0073e9SAndroid Build Coastguard Worker dispatch_keys = set() 1442*da0073e9SAndroid Build Coastguard Worker for dispatch_key, backend_idx in backend_indices.items(): 1443*da0073e9SAndroid Build Coastguard Worker backend_metadata = backend_idx.get_kernel(f) 1444*da0073e9SAndroid Build Coastguard Worker if backend_metadata: 1445*da0073e9SAndroid Build Coastguard Worker namespace = backend_metadata.cpp_namespace 1446*da0073e9SAndroid Build Coastguard Worker dispatch_keys.add(dispatch_key) 1447*da0073e9SAndroid Build Coastguard Worker native_function_namespaces.add(namespace) 1448*da0073e9SAndroid Build Coastguard Worker else: 1449*da0073e9SAndroid Build Coastguard Worker namespace = DEFAULT_KERNEL_NAMESPACE 1450*da0073e9SAndroid Build Coastguard Worker assert ( 1451*da0073e9SAndroid Build Coastguard Worker len(native_function_namespaces) <= 1 1452*da0073e9SAndroid Build Coastguard Worker ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}" 1453*da0073e9SAndroid Build Coastguard Worker ns_grouped_kernels[namespace].extend( 1454*da0073e9SAndroid Build Coastguard Worker native_function_decl_gen(f, backend_idx) 1455*da0073e9SAndroid Build Coastguard Worker ) 1456*da0073e9SAndroid Build Coastguard Worker return ns_grouped_kernels 1457*da0073e9SAndroid Build Coastguard Worker 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Workerdef get_native_function_declarations_from_ns_grouped_kernels( 1460*da0073e9SAndroid Build Coastguard Worker *, 1461*da0073e9SAndroid Build Coastguard Worker ns_grouped_kernels: dict[str, list[str]], 
1462*da0073e9SAndroid Build Coastguard Worker) -> list[str]: 1463*da0073e9SAndroid Build Coastguard Worker declarations: list[str] = [] 1464*da0073e9SAndroid Build Coastguard Worker newline = "\n" 1465*da0073e9SAndroid Build Coastguard Worker for namespace, kernels in ns_grouped_kernels.items(): 1466*da0073e9SAndroid Build Coastguard Worker ns_helper = NamespaceHelper( 1467*da0073e9SAndroid Build Coastguard Worker namespace_str=namespace, 1468*da0073e9SAndroid Build Coastguard Worker entity_name="", 1469*da0073e9SAndroid Build Coastguard Worker max_level=4, 1470*da0073e9SAndroid Build Coastguard Worker ) 1471*da0073e9SAndroid Build Coastguard Worker # Convert to a set first to remove duplicate kernel names. Backends are 1472*da0073e9SAndroid Build Coastguard Worker # allowed to repeat kernel names; only generate the declaration once! 1473*da0073e9SAndroid Build Coastguard Worker ordered_kernels = list(OrderedDict.fromkeys(kernels)) 1474*da0073e9SAndroid Build Coastguard Worker declarations.extend( 1475*da0073e9SAndroid Build Coastguard Worker f""" 1476*da0073e9SAndroid Build Coastguard Worker{ns_helper.prologue} 1477*da0073e9SAndroid Build Coastguard Worker{newline.join(ordered_kernels)} 1478*da0073e9SAndroid Build Coastguard Worker{ns_helper.epilogue} 1479*da0073e9SAndroid Build Coastguard Worker """.split( 1480*da0073e9SAndroid Build Coastguard Worker newline 1481*da0073e9SAndroid Build Coastguard Worker ) 1482*da0073e9SAndroid Build Coastguard Worker ) 1483*da0073e9SAndroid Build Coastguard Worker return declarations 1484*da0073e9SAndroid Build Coastguard Worker 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker# Return native function declarations grouped by their namespaces. 
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
    :return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
    """

    # Step 1: bucket the generated declarations by kernel namespace.
    ns_grouped_kernels = get_ns_grouped_kernels(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
        native_function_decl_gen=native_function_decl_gen,
    )
    # Step 2: wrap every bucket in its namespace prologue/epilogue and flatten to lines.
    return get_native_function_declarations_from_ns_grouped_kernels(
        ns_grouped_kernels=ns_grouped_kernels
    )


def get_kernel_namespace(
    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
) -> str:
    """
    Return the C++ namespace of the kernel that `backend_idx` holds for `f`.
    :param f: the native function (or group) to look up.
    :param backend_idx: the backend index queried through `get_kernel`.
    :return: the kernel's `cpp_namespace`, or `DEFAULT_KERNEL_NAMESPACE` when
        `backend_idx.get_kernel(f)` yields nothing.
    :raises AssertionError: if a kernel exists and its namespace does not contain "::native".
    """
    backend_metadata = backend_idx.get_kernel(f)
    # NOTE: the check below is a substring test, although the message speaks of a suffix.
    assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
        f"with dispatch key {backend_idx.dispatch_key}"
        f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
    )
    return (
        backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
    )


# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and etc.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> list[str]:
    """
    Generate kernel definitions and dispatcher registrations for one dispatch key.
    :param fm: file manager used to render the `RegisterDispatchDefinitions.ini` template.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param dispatch_key: dispatch key named in the emitted `TORCH_LIBRARY_IMPL` blocks;
        its lowercase form is passed to the template as `dispatch_namespace`.
    :param backend_idx: backend index the kernels are looked up in.
    :param selector: forwarded to every `dest.RegisterDispatchKey` generator.
    :param rocm: forwarded to every `dest.RegisterDispatchKey` generator.
    :param symint: forwarded to every `dest.RegisterDispatchKey` generator.
    :param skip_dispatcher_op_registration: when true, the registration body handed to the
        template is left empty.
    :param gen_dispatch_helpers: when true, `dest.gen_registration_helpers(backend_idx)` is
        handed to the template; otherwise an empty list.
    :return: the rendered template output, split by newline, for every kernel namespace
        that has at least one namespaced definition.
    """
    definitions: list[str] = []
    ns_definitions: dict[str, list[str]] = defaultdict(list)
    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
    # kernel namespace -> operator namespace -> registration lines
    registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
    newline = "\n"
    ns_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    anonymous_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.ANONYMOUS_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    reg_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.REGISTRATION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(
            ns_gen(f),
        )
        anonymous_definitions[kernel_namespace].extend(
            anonymous_gen(f),
        )
        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        # Create only the missing inner list. Replacing the whole inner mapping when a
        # new operator namespace shows up would discard the registrations already
        # collected for other operator namespaces under the same kernel namespace.
        registrations[kernel_namespace].setdefault(namespace, []).extend(
            reg_gen(f),
        )

    for kernel_namespace in ns_definitions:
        if not ns_definitions[kernel_namespace]:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace, regs in registrations[kernel_namespace].items():
            if not regs:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(regs)}
}};"""
        # NOTE(review): the lambda reads loop-local names (ns_helper, registration_body,
        # kernel_namespace); this presumes the template callback runs before the next
        # iteration rebinds them - confirm against FileManager.substitute_with_template.
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions
# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    """
    Generate namespaced kernel declarations for one dispatch key.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param dispatch_key: its lowercase form replaces "native" in each kernel namespace.
    :param backend_idx: backend index the kernels are looked up in.
    :param selector: forwarded to `dest.RegisterDispatchKey`.
    :param rocm: forwarded to `dest.RegisterDispatchKey`.
    :param symint: forwarded to `dest.RegisterDispatchKey`.
    :return: declaration lines, wrapped per namespace in its prologue/epilogue, split by newline.
    """
    line_sep = "\n"
    declaration_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )

    # Bucket the generated declarations by their target namespace.
    kernels_by_ns: dict[str, list[str]] = defaultdict(list)
    for fn_or_group in grouped_native_functions:
        kernel_ns = get_kernel_namespace(f=fn_or_group, backend_idx=backend_idx)
        target_ns = kernel_ns.replace("native", dispatch_key.lower())
        kernels_by_ns[target_ns].extend(declaration_gen(fn_or_group))

    declarations: list[str] = []
    for target_ns, kernels in kernels_by_ns.items():
        if not kernels:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=target_ns, entity_name="", max_level=3
        )
        # Drop repeated declarations while keeping first-seen order.
        unique_kernels = list(OrderedDict.fromkeys(kernels))
        rendered = f"""
{ns_helper.prologue}
{line_sep.join(unique_kernels)}
{ns_helper.epilogue}
        """
        declarations.extend(rendered.split(line_sep))
    return declarations


# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    """Return native function schema registration code for aten and other namespaces.

    Args:
        native_functions: Functions to register. They are bucketed by their
            ``namespace`` attribute before any code is produced.
        schema_selector: Selector passed to ``RegisterSchema``.

    Returns:
        A pair ``(aten_schema_registrations, schema_registrations)``. The first
        element is the list of registration entries produced for the ``aten``
        namespace (empty if that namespace is absent). The second is one string
        containing a ``TORCH_LIBRARY`` / ``TORCH_LIBRARY_FRAGMENT`` block for
        every other namespace (empty if there is none).
    """
    ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_native_functions[native_function.namespace].append(native_function)

    # Loop-invariant separator, kept as a name so the f-string expression below
    # does not need a backslash literal.
    tab = "\t"
    schema_registrations = ""
    aten_schema_registrations: list[str] = []
    for namespace, funcs in ns_native_functions.items():
        schema_registrations_body = list(
            mapMaybe(RegisterSchema(schema_selector), funcs)
        )
        # NB: we have to separate aten namespace registration from other namespaces,
        # because in the template we hardcoded an operator for ATen already.
        if namespace == "aten":
            aten_schema_registrations = schema_registrations_body
            continue

        # If the namespace is predefined, we should define a library fragment
        # instead of a new library.
        torch_library_macro = (
            "TORCH_LIBRARY_FRAGMENT"
            if namespace in FRAGMENT_NAMESPACES
            else "TORCH_LIBRARY"
        )
        schema_registrations += f"""
{torch_library_macro}({namespace}, m) {{
  {tab.join(schema_registrations_body)}
}};"""
    return (aten_schema_registrations, schema_registrations)


def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write the aggregated (one file per category) operator headers.

    Through ``cpu_fm`` this writes ``NativeMetaFunctions.h``,
    ``MethodOperators.h``, ``Operators.h``, ``Functions.h`` and
    ``NativeFunctions.h``. Then, for every key of ``dispatch_keys`` that is also
    in ``functions_keys``, it writes ``{key}Functions.h`` and
    ``{key}Functions_inl.h``.

    Args:
        native_functions: All functions; split into method / non-method sets
            for ``MethodOperators.h`` and ``Operators.h``.
        grouped_native_functions: Passed to ``get_native_function_declarations``
            and ``get_namespaced_declaration``.
        structured_native_functions: Source of the meta function declarations.
        static_dispatch_idx: Backend indices passed to ``ComputeOperators`` and
            ``static_dispatch_extra_headers``.
        selector: Passed to ``get_namespaced_declaration``.
        backend_indices: Looked up by dispatch key for the per-key headers.
        cpu_fm: Receives the aggregated headers and the per-key headers of keys
            for which ``is_cuda_dispatch_key`` is false.
        cuda_fm: Receives the per-key headers of keys for which
            ``is_cuda_dispatch_key`` is true.
        functions_keys: Dispatch keys that get ``{key}Functions*.h`` headers.
        dispatch_keys: Dispatch keys to iterate over.
        rocm: Passed to ``get_namespaced_declaration``.
    """
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )

    # Split in one pass. A function is a "non-method" function exactly when it
    # lacks the method variant, so there is no need to probe the method list
    # for every element (which costs a full list scan per function).
    method_native_functions: list[NativeFunction] = []
    non_method_native_functions: list[NativeFunction] = []
    for fn in native_functions:
        if Variant.method in fn.variants:
            method_native_functions.append(fn)
        else:
            non_method_native_functions.append(fn)

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            # NOTE(review): the callables below close over loop variables, so
            # this presumably relies on the file manager invoking them before
            # the next iteration - confirm against FileManager.
            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": get_namespaced_declaration(
                        grouped_native_functions=grouped_native_functions,
                        dispatch_key=dispatch_key,
                        backend_idx=backend_indices[dispatch_key],
                        selector=selector,
                        rocm=rocm,
                        symint=True,
                    ),
                },
            )

        # Drop the per-iteration alias so it cannot be reused by accident.
        del fm


def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write one set of headers per operator root name plus umbrella headers.

    For every root name this writes, through ``ops_fm``, ``{name}_ops.h``,
    ``{name}.h``, ``{name}_native.h`` and, only when the root name has a
    structured ``NativeFunctionsGroup``, ``{name}_meta.h``. It then writes the
    umbrella headers ``Functions.h``, ``Operators.h``,
    ``NativeMetaFunctions.h`` and ``NativeFunctions.h`` through ``cpu_fm``;
    these contain only ``#include`` lines. For each key of ``dispatch_keys``
    that is in ``functions_keys`` it writes ``{name}_{namespace}_dispatch.h``
    for every root name that yields declarations, followed by
    ``{key}Functions.h`` and ``{key}Functions_inl.h``. Finally it writes
    ``MethodOperators.h``.

    Args:
        native_functions: All functions; bucketed by ``root_name``.
        grouped_native_functions: Functions and groups; bucketed by
            ``root_name``.
        static_dispatch_idx: Backend indices passed to ``ComputeOperators`` and
            ``static_dispatch_ops_header``.
        selector: Passed to ``dest.RegisterDispatchKey``.
        backend_indices: Looked up by dispatch key and passed to
            ``get_native_function_declarations``.
        cpu_fm: Receives the umbrella headers and the per-key headers of keys
            for which ``is_cuda_dispatch_key`` is false.
        cuda_fm: Receives the per-key headers of keys for which
            ``is_cuda_dispatch_key`` is true.
        ops_fm: Receives every per-operator header.
        functions_keys: Dispatch keys that get dispatch headers.
        dispatch_keys: Dispatch keys to iterate over.
        rocm: Passed to ``dest.RegisterDispatchKey``.
    """
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        grouped_functions_by_root_name[group.root_name].append(group)

    # NOTE(review): the callables passed to the file managers in the loops
    # below close over loop variables, so this presumably relies on them being
    # invoked before the next iteration - confirm against FileManager.
    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        # .get() rather than indexing: indexing a defaultdict would insert an
        # empty entry for root names that have no grouped functions.
        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = bool(structured_functions)

        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    # The include order is the same for all four umbrella headers, so sort the
    # root names once instead of once per header.
    sorted_root_names = sorted(functions_by_root_name)
    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted_root_names
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        # Root names that produced at least one declaration for this key.
        dispatch_names: list[str] = []

        for name in functions_by_root_name:
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=rocm,
                        symint=True,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    grouped_functions,
                )
            )

            # No header is written for operators with nothing to declare.
            if not declarations:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        # Drop the per-iteration alias so it cannot be reused by accident.
        del fm

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )


def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """Write all generated headers.

    Delegates to ``gen_per_operator_headers`` when ``per_operator_headers`` is
    true and to ``gen_aggregated_headers`` otherwise. It then writes
    ``TensorBody.h``, ``aten_interned_strings.h`` and ``enum_tag.h`` through
    ``core_fm``, and ``RedispatchFunctions.h``, ``RegistrationDeclarations.h``
    and ``VmapGeneratedPlumbing.h`` through ``cpu_fm``.

    Args:
        native_functions: All functions; input to every header written here.
        valid_tags: Tag names, emitted sorted into ``enum_tag.h``.
        grouped_native_functions: Forwarded to the delegated generator.
        structured_native_functions: Forwarded to ``gen_aggregated_headers``
            only.
        static_dispatch_idx: Forwarded to the delegated generator and to
            ``ComputeTensorMethod``.
        selector: Forwarded to the delegated generator.
        backend_indices: Forwarded to the delegated generator and to
            ``compute_registration_declarations``.
        core_fm: File manager for ``TensorBody.h``, ``aten_interned_strings.h``
            and ``enum_tag.h``.
        cpu_fm: File manager for the remaining headers written here; also
            forwarded to the delegated generator.
        cuda_fm: Forwarded to the delegated generator.
        ops_fm: Forwarded to ``gen_per_operator_headers`` only.
        dispatch_keys: Forwarded to the delegated generator.
        functions_keys: Forwarded to the delegated generator.
        rocm: Forwarded to the delegated generator.
        per_operator_headers: Selects per-operator or aggregated headers.
    """
    if per_operator_headers:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )

    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        """Build the ``aten_symbols`` / ``attr_symbols`` macro bodies."""
        attrs: set[str] = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)

            attrs.update(arg.name for arg in func.func.schema_order_arguments())

        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        # Entries are sorted and joined with a trailing backslash so the result
        # forms one multi-line macro body.
        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
        """Build the comma-separated, sorted enumerator list of valid tags."""
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)
Coastguard Worker "attr_symbols": " \\\n".join( 2159*da0073e9SAndroid Build Coastguard Worker [f"_(attr, {name})" for name in sorted(attrs)] 2160*da0073e9SAndroid Build Coastguard Worker ), 2161*da0073e9SAndroid Build Coastguard Worker } 2162*da0073e9SAndroid Build Coastguard Worker 2163*da0073e9SAndroid Build Coastguard Worker core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) 2164*da0073e9SAndroid Build Coastguard Worker 2165*da0073e9SAndroid Build Coastguard Worker def gen_tags_enum() -> dict[str, str]: 2166*da0073e9SAndroid Build Coastguard Worker return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))} 2167*da0073e9SAndroid Build Coastguard Worker 2168*da0073e9SAndroid Build Coastguard Worker core_fm.write("enum_tag.h", gen_tags_enum) 2169*da0073e9SAndroid Build Coastguard Worker 2170*da0073e9SAndroid Build Coastguard Worker 2171*da0073e9SAndroid Build Coastguard Workerdef gen_source_files( 2172*da0073e9SAndroid Build Coastguard Worker *, 2173*da0073e9SAndroid Build Coastguard Worker native_functions: Sequence[NativeFunction], 2174*da0073e9SAndroid Build Coastguard Worker grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 2175*da0073e9SAndroid Build Coastguard Worker structured_native_functions: Sequence[NativeFunctionsGroup], 2176*da0073e9SAndroid Build Coastguard Worker view_groups: Sequence[NativeFunctionsViewGroup], 2177*da0073e9SAndroid Build Coastguard Worker selector: SelectiveBuilder, 2178*da0073e9SAndroid Build Coastguard Worker static_dispatch_idx: list[BackendIndex], 2179*da0073e9SAndroid Build Coastguard Worker backend_indices: dict[DispatchKey, BackendIndex], 2180*da0073e9SAndroid Build Coastguard Worker aoti_fm: FileManager, 2181*da0073e9SAndroid Build Coastguard Worker core_fm: FileManager, 2182*da0073e9SAndroid Build Coastguard Worker cpu_fm: FileManager, 2183*da0073e9SAndroid Build Coastguard Worker cpu_vec_fm: FileManager, 2184*da0073e9SAndroid Build Coastguard Worker cuda_fm: FileManager, 
2185*da0073e9SAndroid Build Coastguard Worker dispatch_keys: Sequence[DispatchKey], 2186*da0073e9SAndroid Build Coastguard Worker functions_keys: set[DispatchKey], 2187*da0073e9SAndroid Build Coastguard Worker rocm: bool, 2188*da0073e9SAndroid Build Coastguard Worker force_schema_registration: bool, 2189*da0073e9SAndroid Build Coastguard Worker per_operator_headers: bool, 2190*da0073e9SAndroid Build Coastguard Worker skip_dispatcher_op_registration: bool, 2191*da0073e9SAndroid Build Coastguard Worker update_aoti_c_shim: bool, 2192*da0073e9SAndroid Build Coastguard Worker) -> None: 2193*da0073e9SAndroid Build Coastguard Worker extra_cuda_headers = """\ 2194*da0073e9SAndroid Build Coastguard Worker#include <c10/cuda/CUDAGuard.h> 2195*da0073e9SAndroid Build Coastguard Worker#include <ATen/cuda/ATenCUDAGeneral.h> 2196*da0073e9SAndroid Build Coastguard Worker#include <ATen/cuda/CUDADevice.h> 2197*da0073e9SAndroid Build Coastguard Worker#include <ATen/cuda/CUDAContext.h>""" 2198*da0073e9SAndroid Build Coastguard Worker if rocm: 2199*da0073e9SAndroid Build Coastguard Worker extra_cuda_headers = """\ 2200*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h> 2201*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/ATenHIPGeneral.h> 2202*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/HIPDevice.h> 2203*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/HIPContext.h>""" 2204*da0073e9SAndroid Build Coastguard Worker 2205*da0073e9SAndroid Build Coastguard Worker for dispatch_key in dispatch_keys: 2206*da0073e9SAndroid Build Coastguard Worker fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm 2207*da0073e9SAndroid Build Coastguard Worker 2208*da0073e9SAndroid Build Coastguard Worker if per_operator_headers: 2209*da0073e9SAndroid Build Coastguard Worker 2210*da0073e9SAndroid Build Coastguard Worker def operator_headers() -> list[str]: 2211*da0073e9SAndroid Build Coastguard Worker headers = [] 
2212*da0073e9SAndroid Build Coastguard Worker for g in grouped_native_functions: 2213*da0073e9SAndroid Build Coastguard Worker is_registered = False 2214*da0073e9SAndroid Build Coastguard Worker if backend_index.has_kernel(g): 2215*da0073e9SAndroid Build Coastguard Worker is_registered = True 2216*da0073e9SAndroid Build Coastguard Worker # The above has_kernel test on a group will only test for 2217*da0073e9SAndroid Build Coastguard Worker # the existence of out dispatch, because that's how 2218*da0073e9SAndroid Build Coastguard Worker # structured kernels work. But sometimes functions can be 2219*da0073e9SAndroid Build Coastguard Worker # grouped but not be structured, and then you need to check 2220*da0073e9SAndroid Build Coastguard Worker # each individual piece, as they may have manual dispatch 2221*da0073e9SAndroid Build Coastguard Worker # entries. 2222*da0073e9SAndroid Build Coastguard Worker elif isinstance(g, NativeFunctionsGroup) and any( 2223*da0073e9SAndroid Build Coastguard Worker backend_index.has_kernel(fn) for fn in g.functions() 2224*da0073e9SAndroid Build Coastguard Worker ): 2225*da0073e9SAndroid Build Coastguard Worker is_registered = True 2226*da0073e9SAndroid Build Coastguard Worker # TODO: this condition is a bit questionable 2227*da0073e9SAndroid Build Coastguard Worker # (It has to do with the fact that structured kernels get generated kernels 2228*da0073e9SAndroid Build Coastguard Worker # to the Meta + CompositeExplicitAutogradNonFunctional keys). 
2229*da0073e9SAndroid Build Coastguard Worker elif g.structured and dispatch_key in ( 2230*da0073e9SAndroid Build Coastguard Worker DispatchKey.Meta, 2231*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional, 2232*da0073e9SAndroid Build Coastguard Worker ): 2233*da0073e9SAndroid Build Coastguard Worker is_registered = True 2234*da0073e9SAndroid Build Coastguard Worker if not is_registered: 2235*da0073e9SAndroid Build Coastguard Worker continue 2236*da0073e9SAndroid Build Coastguard Worker 2237*da0073e9SAndroid Build Coastguard Worker headers.append(f"#include <ATen/ops/{g.root_name}_native.h>") 2238*da0073e9SAndroid Build Coastguard Worker if ( 2239*da0073e9SAndroid Build Coastguard Worker dispatch_key 2240*da0073e9SAndroid Build Coastguard Worker == DispatchKey.CompositeExplicitAutogradNonFunctional 2241*da0073e9SAndroid Build Coastguard Worker ): 2242*da0073e9SAndroid Build Coastguard Worker headers.append(f"#include <ATen/ops/{g.root_name}.h>") 2243*da0073e9SAndroid Build Coastguard Worker if dispatch_key in functions_keys: 2244*da0073e9SAndroid Build Coastguard Worker headers.append( 2245*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>" 2246*da0073e9SAndroid Build Coastguard Worker ) 2247*da0073e9SAndroid Build Coastguard Worker 2248*da0073e9SAndroid Build Coastguard Worker return sorted(set(headers)) 2249*da0073e9SAndroid Build Coastguard Worker 2250*da0073e9SAndroid Build Coastguard Worker else: 2251*da0073e9SAndroid Build Coastguard Worker 2252*da0073e9SAndroid Build Coastguard Worker def operator_headers() -> list[str]: 2253*da0073e9SAndroid Build Coastguard Worker headers = ["#include <ATen/NativeFunctions.h>"] 2254*da0073e9SAndroid Build Coastguard Worker if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: 2255*da0073e9SAndroid Build Coastguard Worker headers.append("#include <ATen/Functions.h>") 2256*da0073e9SAndroid Build Coastguard 
Worker if dispatch_key in functions_keys: 2257*da0073e9SAndroid Build Coastguard Worker headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>") 2258*da0073e9SAndroid Build Coastguard Worker return headers 2259*da0073e9SAndroid Build Coastguard Worker 2260*da0073e9SAndroid Build Coastguard Worker backend_index = backend_indices[dispatch_key] 2261*da0073e9SAndroid Build Coastguard Worker ns_grouped_native_functions = defaultdict(list) 2262*da0073e9SAndroid Build Coastguard Worker for grouped_native_function in grouped_native_functions: 2263*da0073e9SAndroid Build Coastguard Worker namespace = ( 2264*da0073e9SAndroid Build Coastguard Worker grouped_native_function.namespace 2265*da0073e9SAndroid Build Coastguard Worker if isinstance(grouped_native_function, NativeFunction) 2266*da0073e9SAndroid Build Coastguard Worker else grouped_native_function.functional.namespace 2267*da0073e9SAndroid Build Coastguard Worker ) 2268*da0073e9SAndroid Build Coastguard Worker ns_grouped_native_functions[namespace].append(grouped_native_function) 2269*da0073e9SAndroid Build Coastguard Worker 2270*da0073e9SAndroid Build Coastguard Worker dispatch_namespace = str(dispatch_key).lower() 2271*da0073e9SAndroid Build Coastguard Worker 2272*da0073e9SAndroid Build Coastguard Worker # CompositeImplicitAutogradNestedTensor does not currently use the helpers generated; 2273*da0073e9SAndroid Build Coastguard Worker # compilation will fail when `-Werror=unused-function` flag is set 2274*da0073e9SAndroid Build Coastguard Worker gen_dispatch_helpers: bool = ( 2275*da0073e9SAndroid Build Coastguard Worker dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor 2276*da0073e9SAndroid Build Coastguard Worker ) 2277*da0073e9SAndroid Build Coastguard Worker 2278*da0073e9SAndroid Build Coastguard Worker dispatch_definitions = get_native_function_definitions( 2279*da0073e9SAndroid Build Coastguard Worker fm=fm, 2280*da0073e9SAndroid Build Coastguard Worker 
grouped_native_functions=grouped_native_functions, 2281*da0073e9SAndroid Build Coastguard Worker dispatch_key=dispatch_key, 2282*da0073e9SAndroid Build Coastguard Worker backend_idx=backend_index, 2283*da0073e9SAndroid Build Coastguard Worker selector=selector, 2284*da0073e9SAndroid Build Coastguard Worker rocm=rocm, 2285*da0073e9SAndroid Build Coastguard Worker symint=True, 2286*da0073e9SAndroid Build Coastguard Worker skip_dispatcher_op_registration=skip_dispatcher_op_registration, 2287*da0073e9SAndroid Build Coastguard Worker gen_dispatch_helpers=gen_dispatch_helpers, 2288*da0073e9SAndroid Build Coastguard Worker ) 2289*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 2290*da0073e9SAndroid Build Coastguard Worker f"Register{dispatch_key}.cpp", 2291*da0073e9SAndroid Build Coastguard Worker "RegisterDispatchKey.cpp", 2292*da0073e9SAndroid Build Coastguard Worker lambda: { 2293*da0073e9SAndroid Build Coastguard Worker "extra_cuda_headers": extra_cuda_headers 2294*da0073e9SAndroid Build Coastguard Worker if is_cuda_dispatch_key(dispatch_key) 2295*da0073e9SAndroid Build Coastguard Worker else "", 2296*da0073e9SAndroid Build Coastguard Worker "external_backend_headers": "", 2297*da0073e9SAndroid Build Coastguard Worker "dispatch_headers": dest.gen_registration_headers( 2298*da0073e9SAndroid Build Coastguard Worker backend_index, per_operator_headers, rocm 2299*da0073e9SAndroid Build Coastguard Worker ), 2300*da0073e9SAndroid Build Coastguard Worker "ops_headers": operator_headers(), 2301*da0073e9SAndroid Build Coastguard Worker "dispatch_helpers": "", 2302*da0073e9SAndroid Build Coastguard Worker "dispatch_definitions": dispatch_definitions, 2303*da0073e9SAndroid Build Coastguard Worker }, 2304*da0073e9SAndroid Build Coastguard Worker ) 2305*da0073e9SAndroid Build Coastguard Worker 2306*da0073e9SAndroid Build Coastguard Worker for g in structured_native_functions: 2307*da0073e9SAndroid Build Coastguard Worker if not g.out.ufunc_inner_loop or not 
is_ufunc_dispatch_key(dispatch_key): 2308*da0073e9SAndroid Build Coastguard Worker continue 2309*da0073e9SAndroid Build Coastguard Worker name = g.functional.func.name.name 2310*da0073e9SAndroid Build Coastguard Worker if dispatch_key is DispatchKey.CPU: 2311*da0073e9SAndroid Build Coastguard Worker assert fm is cpu_fm 2312*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 2313*da0073e9SAndroid Build Coastguard Worker f"UfuncCPU_{name}.cpp", 2314*da0073e9SAndroid Build Coastguard Worker "UfuncCPU.cpp", 2315*da0073e9SAndroid Build Coastguard Worker lambda: { 2316*da0073e9SAndroid Build Coastguard Worker "meta_declaration": compute_meta_function_declaration(g), 2317*da0073e9SAndroid Build Coastguard Worker "native_declaration": dest.compute_native_function_declaration( 2318*da0073e9SAndroid Build Coastguard Worker g, backend_indices[dispatch_key] 2319*da0073e9SAndroid Build Coastguard Worker ), 2320*da0073e9SAndroid Build Coastguard Worker "native_definitions": dest.compute_ufunc_cpu(g), 2321*da0073e9SAndroid Build Coastguard Worker }, 2322*da0073e9SAndroid Build Coastguard Worker ) 2323*da0073e9SAndroid Build Coastguard Worker cpu_vec_fm.write_with_template( 2324*da0073e9SAndroid Build Coastguard Worker f"UfuncCPUKernel_{name}.cpp", 2325*da0073e9SAndroid Build Coastguard Worker "UfuncCPUKernel.cpp", 2326*da0073e9SAndroid Build Coastguard Worker lambda: { 2327*da0073e9SAndroid Build Coastguard Worker "name": name, 2328*da0073e9SAndroid Build Coastguard Worker "native_definitions": dest.compute_ufunc_cpu_kernel(g), 2329*da0073e9SAndroid Build Coastguard Worker }, 2330*da0073e9SAndroid Build Coastguard Worker ) 2331*da0073e9SAndroid Build Coastguard Worker elif dispatch_key is DispatchKey.CUDA: 2332*da0073e9SAndroid Build Coastguard Worker cuda_headers = "#include <ATen/native/cuda/Loops.cuh>" 2333*da0073e9SAndroid Build Coastguard Worker if rocm: 2334*da0073e9SAndroid Build Coastguard Worker cuda_headers = "#include <ATen/native/hip/Loops.cuh>" 
2335*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 2336*da0073e9SAndroid Build Coastguard Worker f"UfuncCUDA_{name}.cu", 2337*da0073e9SAndroid Build Coastguard Worker "UfuncCUDA.cu", 2338*da0073e9SAndroid Build Coastguard Worker lambda: { 2339*da0073e9SAndroid Build Coastguard Worker "name": name, 2340*da0073e9SAndroid Build Coastguard Worker "cuda_headers": cuda_headers, 2341*da0073e9SAndroid Build Coastguard Worker "meta_declaration": compute_meta_function_declaration(g), 2342*da0073e9SAndroid Build Coastguard Worker "native_declaration": dest.compute_native_function_declaration( 2343*da0073e9SAndroid Build Coastguard Worker g, backend_indices[dispatch_key] 2344*da0073e9SAndroid Build Coastguard Worker ), 2345*da0073e9SAndroid Build Coastguard Worker "native_definitions": dest.compute_ufunc_cuda(g), 2346*da0073e9SAndroid Build Coastguard Worker }, 2347*da0073e9SAndroid Build Coastguard Worker ) 2348*da0073e9SAndroid Build Coastguard Worker else: 2349*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"unrecognized {dispatch_key} for ufunc") 2350*da0073e9SAndroid Build Coastguard Worker 2351*da0073e9SAndroid Build Coastguard Worker structured_func_group_dict = {} 2352*da0073e9SAndroid Build Coastguard Worker for func_group in structured_native_functions: 2353*da0073e9SAndroid Build Coastguard Worker for func in func_group.functions(): 2354*da0073e9SAndroid Build Coastguard Worker if func.structured_delegate is not None: 2355*da0073e9SAndroid Build Coastguard Worker structured_func_group_dict[func.structured_delegate] = func_group 2356*da0073e9SAndroid Build Coastguard Worker break 2357*da0073e9SAndroid Build Coastguard Worker 2358*da0073e9SAndroid Build Coastguard Worker if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA): 2359*da0073e9SAndroid Build Coastguard Worker fallbacks = {} 2360*da0073e9SAndroid Build Coastguard Worker for func in native_functions: 2361*da0073e9SAndroid Build Coastguard Worker op_name = 
get_fallback_op_name(func) 2362*da0073e9SAndroid Build Coastguard Worker if op_name in inductor_fallback_ops: 2363*da0073e9SAndroid Build Coastguard Worker fallbacks[op_name] = func 2364*da0073e9SAndroid Build Coastguard Worker fallback_native_functions = tuple( 2365*da0073e9SAndroid Build Coastguard Worker value for _, value in sorted(fallbacks.items()) 2366*da0073e9SAndroid Build Coastguard Worker ) 2367*da0073e9SAndroid Build Coastguard Worker 2368*da0073e9SAndroid Build Coastguard Worker # header files were checked in for ABI-compatibility checking 2369*da0073e9SAndroid Build Coastguard Worker header_file_name = f"c_shim_{dispatch_key.lower()}.h" 2370*da0073e9SAndroid Build Coastguard Worker new_header = gen_aoti_c_shim( 2371*da0073e9SAndroid Build Coastguard Worker fallback_native_functions, 2372*da0073e9SAndroid Build Coastguard Worker structured_func_group_dict, 2373*da0073e9SAndroid Build Coastguard Worker dispatch_key, 2374*da0073e9SAndroid Build Coastguard Worker backend_indices, 2375*da0073e9SAndroid Build Coastguard Worker header=True, 2376*da0073e9SAndroid Build Coastguard Worker includes="", 2377*da0073e9SAndroid Build Coastguard Worker ) 2378*da0073e9SAndroid Build Coastguard Worker if update_aoti_c_shim: 2379*da0073e9SAndroid Build Coastguard Worker aoti_fm.write( 2380*da0073e9SAndroid Build Coastguard Worker header_file_name, 2381*da0073e9SAndroid Build Coastguard Worker lambda: new_header, 2382*da0073e9SAndroid Build Coastguard Worker ) 2383*da0073e9SAndroid Build Coastguard Worker else: 2384*da0073e9SAndroid Build Coastguard Worker try: 2385*da0073e9SAndroid Build Coastguard Worker with open( 2386*da0073e9SAndroid Build Coastguard Worker os.path.join(aoti_fm.install_dir, header_file_name) 2387*da0073e9SAndroid Build Coastguard Worker ) as old_file: 2388*da0073e9SAndroid Build Coastguard Worker old_header = old_file.read() 2389*da0073e9SAndroid Build Coastguard Worker assert ( 2390*da0073e9SAndroid Build Coastguard Worker old_header == new_header 
2391*da0073e9SAndroid Build Coastguard Worker ), """ 2392*da0073e9SAndroid Build Coastguard Worker 2393*da0073e9SAndroid Build Coastguard WorkerWARNING: The generated AOTInductor C shim header files have unexpectedly changed. This 2394*da0073e9SAndroid Build Coastguard Workerindicates an AOTInductor fallback operator ABI backward compatibility breakage!!! 2395*da0073e9SAndroid Build Coastguard WorkerOnly in a limited number of situations, this is allowed: 2396*da0073e9SAndroid Build Coastguard Worker 2397*da0073e9SAndroid Build Coastguard Worker1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py. 2398*da0073e9SAndroid Build Coastguard WorkerIf that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing 2399*da0073e9SAndroid Build Coastguard WorkerC shim header files. 2400*da0073e9SAndroid Build Coastguard Worker 2401*da0073e9SAndroid Build Coastguard Worker2. You added a new default argument to an existing fallback op. This is clearly a BC breaking 2402*da0073e9SAndroid Build Coastguard Workerchange in the AOTInductor land. In this case, you need to keep a manual copy of that existing 2403*da0073e9SAndroid Build Coastguard Workerfallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version 2404*da0073e9SAndroid Build Coastguard Workernumber of that fallback op in the newly generated C shim files, and update the cpp wrapper 2405*da0073e9SAndroid Build Coastguard Workercodegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance. 
2406*da0073e9SAndroid Build Coastguard Worker 2407*da0073e9SAndroid Build Coastguard Worker """ 2408*da0073e9SAndroid Build Coastguard Worker except FileNotFoundError: 2409*da0073e9SAndroid Build Coastguard Worker print( 2410*da0073e9SAndroid Build Coastguard Worker f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found" 2411*da0073e9SAndroid Build Coastguard Worker ) 2412*da0073e9SAndroid Build Coastguard Worker 2413*da0073e9SAndroid Build Coastguard Worker # cpp files are always generated on-the-fly 2414*da0073e9SAndroid Build Coastguard Worker def headers_for_aoti() -> str: 2415*da0073e9SAndroid Build Coastguard Worker headers = [] 2416*da0073e9SAndroid Build Coastguard Worker for func in fallback_native_functions: 2417*da0073e9SAndroid Build Coastguard Worker header = get_header_for_aoti( 2418*da0073e9SAndroid Build Coastguard Worker func, structured_func_group_dict, dispatch_key, backend_indices 2419*da0073e9SAndroid Build Coastguard Worker ) 2420*da0073e9SAndroid Build Coastguard Worker if header is not None: 2421*da0073e9SAndroid Build Coastguard Worker headers.append(header) 2422*da0073e9SAndroid Build Coastguard Worker return "\n".join(sorted(set(headers))) 2423*da0073e9SAndroid Build Coastguard Worker 2424*da0073e9SAndroid Build Coastguard Worker extra_headers = ( 2425*da0073e9SAndroid Build Coastguard Worker extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else "" 2426*da0073e9SAndroid Build Coastguard Worker ) 2427*da0073e9SAndroid Build Coastguard Worker 2428*da0073e9SAndroid Build Coastguard Worker aoti_fm.write( 2429*da0073e9SAndroid Build Coastguard Worker f"c_shim_{dispatch_key.lower()}.cpp", 2430*da0073e9SAndroid Build Coastguard Worker lambda: gen_aoti_c_shim( 2431*da0073e9SAndroid Build Coastguard Worker fallback_native_functions, 2432*da0073e9SAndroid Build Coastguard Worker structured_func_group_dict, 2433*da0073e9SAndroid Build Coastguard Worker dispatch_key, 2434*da0073e9SAndroid Build Coastguard Worker backend_indices, 
2435*da0073e9SAndroid Build Coastguard Worker header=False, 2436*da0073e9SAndroid Build Coastguard Worker includes=headers_for_aoti() + "\n" + extra_headers, 2437*da0073e9SAndroid Build Coastguard Worker ), 2438*da0073e9SAndroid Build Coastguard Worker ) 2439*da0073e9SAndroid Build Coastguard Worker 2440*da0073e9SAndroid Build Coastguard Worker del fm 2441*da0073e9SAndroid Build Coastguard Worker 2442*da0073e9SAndroid Build Coastguard Worker # BackendSelect is generated specially 2443*da0073e9SAndroid Build Coastguard Worker def gen_backend_select() -> dict[str, list[str]]: 2444*da0073e9SAndroid Build Coastguard Worker relevant_fns = [ 2445*da0073e9SAndroid Build Coastguard Worker fn for fn in native_functions if needs_backend_select(fn, selector) 2446*da0073e9SAndroid Build Coastguard Worker ] 2447*da0073e9SAndroid Build Coastguard Worker return { 2448*da0073e9SAndroid Build Coastguard Worker "ops_headers": [ 2449*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns 2450*da0073e9SAndroid Build Coastguard Worker ], 2451*da0073e9SAndroid Build Coastguard Worker "backend_select_method_definitions": list( 2452*da0073e9SAndroid Build Coastguard Worker mapMaybe( 2453*da0073e9SAndroid Build Coastguard Worker ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns 2454*da0073e9SAndroid Build Coastguard Worker ) 2455*da0073e9SAndroid Build Coastguard Worker ), 2456*da0073e9SAndroid Build Coastguard Worker "backend_select_function_registrations": list( 2457*da0073e9SAndroid Build Coastguard Worker mapMaybe( 2458*da0073e9SAndroid Build Coastguard Worker ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns 2459*da0073e9SAndroid Build Coastguard Worker ) 2460*da0073e9SAndroid Build Coastguard Worker ), 2461*da0073e9SAndroid Build Coastguard Worker } 2462*da0073e9SAndroid Build Coastguard Worker 2463*da0073e9SAndroid Build Coastguard Worker cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select) 
2464*da0073e9SAndroid Build Coastguard Worker 2465*da0073e9SAndroid Build Coastguard Worker schema_selector = selector 2466*da0073e9SAndroid Build Coastguard Worker if force_schema_registration: 2467*da0073e9SAndroid Build Coastguard Worker schema_selector = SelectiveBuilder.get_nop_selector() 2468*da0073e9SAndroid Build Coastguard Worker 2469*da0073e9SAndroid Build Coastguard Worker ( 2470*da0073e9SAndroid Build Coastguard Worker aten_schema_registrations, 2471*da0073e9SAndroid Build Coastguard Worker schema_registrations, 2472*da0073e9SAndroid Build Coastguard Worker ) = get_native_function_schema_registrations( 2473*da0073e9SAndroid Build Coastguard Worker native_functions=native_functions, schema_selector=schema_selector 2474*da0073e9SAndroid Build Coastguard Worker ) 2475*da0073e9SAndroid Build Coastguard Worker cpu_fm.write( 2476*da0073e9SAndroid Build Coastguard Worker "RegisterSchema.cpp", 2477*da0073e9SAndroid Build Coastguard Worker lambda: { 2478*da0073e9SAndroid Build Coastguard Worker "aten_schema_registrations": [] 2479*da0073e9SAndroid Build Coastguard Worker if skip_dispatcher_op_registration 2480*da0073e9SAndroid Build Coastguard Worker else aten_schema_registrations, 2481*da0073e9SAndroid Build Coastguard Worker "schema_registrations": [] 2482*da0073e9SAndroid Build Coastguard Worker if skip_dispatcher_op_registration 2483*da0073e9SAndroid Build Coastguard Worker else schema_registrations, 2484*da0073e9SAndroid Build Coastguard Worker }, 2485*da0073e9SAndroid Build Coastguard Worker ) 2486*da0073e9SAndroid Build Coastguard Worker 2487*da0073e9SAndroid Build Coastguard Worker def key_func( 2488*da0073e9SAndroid Build Coastguard Worker fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, 2489*da0073e9SAndroid Build Coastguard Worker ) -> str: 2490*da0073e9SAndroid Build Coastguard Worker return fn.root_name 2491*da0073e9SAndroid Build Coastguard Worker 2492*da0073e9SAndroid Build Coastguard Worker cpu_fm.write_sharded( 
2493*da0073e9SAndroid Build Coastguard Worker "Operators.cpp", 2494*da0073e9SAndroid Build Coastguard Worker native_functions, 2495*da0073e9SAndroid Build Coastguard Worker key_fn=key_func, 2496*da0073e9SAndroid Build Coastguard Worker env_callable=lambda fn: { 2497*da0073e9SAndroid Build Coastguard Worker "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"], 2498*da0073e9SAndroid Build Coastguard Worker "definitions": [ 2499*da0073e9SAndroid Build Coastguard Worker ComputeOperators( 2500*da0073e9SAndroid Build Coastguard Worker Target.DEFINITION, 2501*da0073e9SAndroid Build Coastguard Worker static_dispatch_backend_indices=static_dispatch_idx, 2502*da0073e9SAndroid Build Coastguard Worker )(fn) 2503*da0073e9SAndroid Build Coastguard Worker ], 2504*da0073e9SAndroid Build Coastguard Worker }, 2505*da0073e9SAndroid Build Coastguard Worker base_env={ 2506*da0073e9SAndroid Build Coastguard Worker "static_dispatch_extra_headers": static_dispatch_extra_headers( 2507*da0073e9SAndroid Build Coastguard Worker static_dispatch_idx 2508*da0073e9SAndroid Build Coastguard Worker ), 2509*da0073e9SAndroid Build Coastguard Worker }, 2510*da0073e9SAndroid Build Coastguard Worker num_shards=5, 2511*da0073e9SAndroid Build Coastguard Worker sharded_keys={ 2512*da0073e9SAndroid Build Coastguard Worker "operator_headers", 2513*da0073e9SAndroid Build Coastguard Worker "definitions", 2514*da0073e9SAndroid Build Coastguard Worker "static_dispatch_extra_headers", 2515*da0073e9SAndroid Build Coastguard Worker }, 2516*da0073e9SAndroid Build Coastguard Worker ) 2517*da0073e9SAndroid Build Coastguard Worker 2518*da0073e9SAndroid Build Coastguard Worker cpu_fm.write("Functions.cpp", dict) 2519*da0073e9SAndroid Build Coastguard Worker 2520*da0073e9SAndroid Build Coastguard Worker core_fm.write("TensorMethods.cpp", dict) 2521*da0073e9SAndroid Build Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Worker core_fm.write( 2523*da0073e9SAndroid Build Coastguard Worker "ATenOpList.cpp", 
2524*da0073e9SAndroid Build Coastguard Worker lambda: { 2525*da0073e9SAndroid Build Coastguard Worker "aten_ops": list(mapMaybe(compute_aten_op, native_functions)), 2526*da0073e9SAndroid Build Coastguard Worker }, 2527*da0073e9SAndroid Build Coastguard Worker ) 2528*da0073e9SAndroid Build Coastguard Worker 2529*da0073e9SAndroid Build Coastguard Worker def functionalization_env_callable( 2530*da0073e9SAndroid Build Coastguard Worker g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, 2531*da0073e9SAndroid Build Coastguard Worker ) -> dict[str, list[str]]: 2532*da0073e9SAndroid Build Coastguard Worker def gen_op_headers( 2533*da0073e9SAndroid Build Coastguard Worker g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, 2534*da0073e9SAndroid Build Coastguard Worker ) -> list[str]: 2535*da0073e9SAndroid Build Coastguard Worker if isinstance(g, NativeFunctionsViewGroup): 2536*da0073e9SAndroid Build Coastguard Worker # view ops always get a functionalization kernel 2537*da0073e9SAndroid Build Coastguard Worker headers = [ 2538*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.view.root_name}_native.h>", 2539*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.view.root_name}_ops.h>", 2540*da0073e9SAndroid Build Coastguard Worker ] 2541*da0073e9SAndroid Build Coastguard Worker if g.view_copy is not None: 2542*da0073e9SAndroid Build Coastguard Worker headers += [ 2543*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.view_copy.root_name}_native.h>", 2544*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>", 2545*da0073e9SAndroid Build Coastguard Worker ] 2546*da0073e9SAndroid Build Coastguard Worker return headers 2547*da0073e9SAndroid Build Coastguard Worker elif isinstance(g, NativeFunctionsGroup): 2548*da0073e9SAndroid Build Coastguard Worker headers = [ 2549*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.functional.root_name}_native.h>", 
2550*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.functional.root_name}_ops.h>", 2551*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.out.root_name}_native.h>", 2552*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.out.root_name}_ops.h>", 2553*da0073e9SAndroid Build Coastguard Worker ] 2554*da0073e9SAndroid Build Coastguard Worker if g.inplace is not None: 2555*da0073e9SAndroid Build Coastguard Worker headers += [ 2556*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.inplace.root_name}_native.h>", 2557*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.inplace.root_name}_ops.h>", 2558*da0073e9SAndroid Build Coastguard Worker ] 2559*da0073e9SAndroid Build Coastguard Worker if g.mutable is not None: 2560*da0073e9SAndroid Build Coastguard Worker headers += [ 2561*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.mutable.root_name}_native.h>", 2562*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.mutable.root_name}_ops.h>", 2563*da0073e9SAndroid Build Coastguard Worker ] 2564*da0073e9SAndroid Build Coastguard Worker return headers 2565*da0073e9SAndroid Build Coastguard Worker else: 2566*da0073e9SAndroid Build Coastguard Worker return [ 2567*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.root_name}_native.h>", 2568*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{g.root_name}_ops.h>", 2569*da0073e9SAndroid Build Coastguard Worker ] 2570*da0073e9SAndroid Build Coastguard Worker 2571*da0073e9SAndroid Build Coastguard Worker return { 2572*da0073e9SAndroid Build Coastguard Worker "ops_headers": gen_op_headers(g), 2573*da0073e9SAndroid Build Coastguard Worker "func_definitions": gen_functionalization_definition( 2574*da0073e9SAndroid Build Coastguard Worker selector, 2575*da0073e9SAndroid Build Coastguard Worker g, 2576*da0073e9SAndroid Build Coastguard Worker ), 2577*da0073e9SAndroid Build Coastguard Worker "func_registrations": 
gen_functionalization_registration( 2578*da0073e9SAndroid Build Coastguard Worker selector, 2579*da0073e9SAndroid Build Coastguard Worker g, 2580*da0073e9SAndroid Build Coastguard Worker backend_indices[DispatchKey.CompositeImplicitAutograd], 2581*da0073e9SAndroid Build Coastguard Worker ), 2582*da0073e9SAndroid Build Coastguard Worker } 2583*da0073e9SAndroid Build Coastguard Worker 2584*da0073e9SAndroid Build Coastguard Worker all_groups: list[ 2585*da0073e9SAndroid Build Coastguard Worker NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup 2586*da0073e9SAndroid Build Coastguard Worker ] = list(structured_native_functions) + list( 2587*da0073e9SAndroid Build Coastguard Worker view_groups # type: ignore[assignment, arg-type, operator] 2588*da0073e9SAndroid Build Coastguard Worker ) 2589*da0073e9SAndroid Build Coastguard Worker # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly. 2590*da0073e9SAndroid Build Coastguard Worker # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because: 2591*da0073e9SAndroid Build Coastguard Worker # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) 2592*da0073e9SAndroid Build Coastguard Worker # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. 2593*da0073e9SAndroid Build Coastguard Worker # Although this could go away long-term if we add a dedicated dispatch key for decompositions. 
2594*da0073e9SAndroid Build Coastguard Worker structured_map: dict[OperatorName, NativeFunction] = { 2595*da0073e9SAndroid Build Coastguard Worker f.func.name: f 2596*da0073e9SAndroid Build Coastguard Worker for f in concatMap(lambda g: list(g.functions()), structured_native_functions) 2597*da0073e9SAndroid Build Coastguard Worker } 2598*da0073e9SAndroid Build Coastguard Worker view_map: dict[OperatorName, NativeFunction] = { 2599*da0073e9SAndroid Build Coastguard Worker f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups) 2600*da0073e9SAndroid Build Coastguard Worker } 2601*da0073e9SAndroid Build Coastguard Worker for f in native_functions: 2602*da0073e9SAndroid Build Coastguard Worker if f.func.name not in structured_map and f.func.name not in view_map: 2603*da0073e9SAndroid Build Coastguard Worker all_groups.append(f) 2604*da0073e9SAndroid Build Coastguard Worker 2605*da0073e9SAndroid Build Coastguard Worker cpu_fm.write_sharded( 2606*da0073e9SAndroid Build Coastguard Worker "RegisterFunctionalization.cpp", 2607*da0073e9SAndroid Build Coastguard Worker all_groups, 2608*da0073e9SAndroid Build Coastguard Worker key_fn=key_func, 2609*da0073e9SAndroid Build Coastguard Worker env_callable=functionalization_env_callable, 2610*da0073e9SAndroid Build Coastguard Worker num_shards=4, 2611*da0073e9SAndroid Build Coastguard Worker sharded_keys={ 2612*da0073e9SAndroid Build Coastguard Worker "ops_headers", 2613*da0073e9SAndroid Build Coastguard Worker "func_definitions", 2614*da0073e9SAndroid Build Coastguard Worker "func_registrations", 2615*da0073e9SAndroid Build Coastguard Worker "func_add_back_views_definitions", 2616*da0073e9SAndroid Build Coastguard Worker "func_add_back_views_registrations", 2617*da0073e9SAndroid Build Coastguard Worker }, 2618*da0073e9SAndroid Build Coastguard Worker ) 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker cpu_fm.write( 2621*da0073e9SAndroid Build Coastguard Worker 
"FunctionalInverses.h", 2622*da0073e9SAndroid Build Coastguard Worker lambda: { 2623*da0073e9SAndroid Build Coastguard Worker "view_inverse_declarations": list( 2624*da0073e9SAndroid Build Coastguard Worker mapMaybe( 2625*da0073e9SAndroid Build Coastguard Worker lambda g: gen_functionalization_view_inverse_declaration( 2626*da0073e9SAndroid Build Coastguard Worker selector, g 2627*da0073e9SAndroid Build Coastguard Worker ), 2628*da0073e9SAndroid Build Coastguard Worker view_groups, 2629*da0073e9SAndroid Build Coastguard Worker ) 2630*da0073e9SAndroid Build Coastguard Worker ) 2631*da0073e9SAndroid Build Coastguard Worker }, 2632*da0073e9SAndroid Build Coastguard Worker ) 2633*da0073e9SAndroid Build Coastguard Worker 2634*da0073e9SAndroid Build Coastguard Worker # Note [view_copy NativeFunctions] 2635*da0073e9SAndroid Build Coastguard Worker # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd 2636*da0073e9SAndroid Build Coastguard Worker # needs to have a corresponding non-aliasing {view}_copy variant. 2637*da0073e9SAndroid Build Coastguard Worker # Backends that use functionalization and don't know how to handle aliasing ops 2638*da0073e9SAndroid Build Coastguard Worker # are expected to implement kernels for these {view}_copy kernels instead. 2639*da0073e9SAndroid Build Coastguard Worker # The code for {view}_copy operators in core is pretty boilerplate-heavy however, 2640*da0073e9SAndroid Build Coastguard Worker # so we codegen the following: 2641*da0073e9SAndroid Build Coastguard Worker # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator. 2642*da0073e9SAndroid Build Coastguard Worker # These are never explicitly invoked by the functionalization pass, 2643*da0073e9SAndroid Build Coastguard Worker # but they could theoretically be called from user code (I added these kernels for completeness, 2644*da0073e9SAndroid Build Coastguard Worker # since the ops are part of the public API). 
2645*da0073e9SAndroid Build Coastguard Worker # (2) A derivative formula for every {view}_copy operator 2646*da0073e9SAndroid Build Coastguard Worker # {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts, 2647*da0073e9SAndroid Build Coastguard Worker # so rather than stamping all of the entries out in derivatives.yaml, 2648*da0073e9SAndroid Build Coastguard Worker # we codegen them in. 2649*da0073e9SAndroid Build Coastguard Worker # This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry. 2650*da0073e9SAndroid Build Coastguard Worker cpu_fm.write( 2651*da0073e9SAndroid Build Coastguard Worker "CompositeViewCopyKernels.cpp", 2652*da0073e9SAndroid Build Coastguard Worker lambda: { 2653*da0073e9SAndroid Build Coastguard Worker "ops_headers": [ 2654*da0073e9SAndroid Build Coastguard Worker "\n".join( 2655*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{f.root_name}_ops.h>\n" 2656*da0073e9SAndroid Build Coastguard Worker # NB: this include is important as it ensures we 2657*da0073e9SAndroid Build Coastguard Worker # set the visibility on generated view_copy kernels 2658*da0073e9SAndroid Build Coastguard Worker # correctly 2659*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{f.root_name}_native.h>" 2660*da0073e9SAndroid Build Coastguard Worker for f in ( 2661*da0073e9SAndroid Build Coastguard Worker [g.view] if g.view_copy is None else [g.view, g.view_copy] 2662*da0073e9SAndroid Build Coastguard Worker ) 2663*da0073e9SAndroid Build Coastguard Worker ) 2664*da0073e9SAndroid Build Coastguard Worker for g in view_groups 2665*da0073e9SAndroid Build Coastguard Worker ] 2666*da0073e9SAndroid Build Coastguard Worker + [ 2667*da0073e9SAndroid Build Coastguard Worker "\n".join( 2668*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{f.root_name}_ops.h>\n" 2669*da0073e9SAndroid Build Coastguard Worker # NB: this include is also important for correct 
visibility 2670*da0073e9SAndroid Build Coastguard Worker f"#include <ATen/ops/{f.root_name}_native.h>" 2671*da0073e9SAndroid Build Coastguard Worker for f in [g.inplace, g.mutable, g.functional] 2672*da0073e9SAndroid Build Coastguard Worker if f is not None and "generated" not in f.tags 2673*da0073e9SAndroid Build Coastguard Worker ) 2674*da0073e9SAndroid Build Coastguard Worker for g in structured_native_functions 2675*da0073e9SAndroid Build Coastguard Worker ], 2676*da0073e9SAndroid Build Coastguard Worker "CompositeViewCopyKernel_Definitions": list( 2677*da0073e9SAndroid Build Coastguard Worker mapMaybe( 2678*da0073e9SAndroid Build Coastguard Worker GenCompositeViewCopyKernel( 2679*da0073e9SAndroid Build Coastguard Worker backend_indices[ 2680*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional 2681*da0073e9SAndroid Build Coastguard Worker ] 2682*da0073e9SAndroid Build Coastguard Worker ), 2683*da0073e9SAndroid Build Coastguard Worker view_groups, 2684*da0073e9SAndroid Build Coastguard Worker ) 2685*da0073e9SAndroid Build Coastguard Worker ), 2686*da0073e9SAndroid Build Coastguard Worker "GeneratedCompositeFunctional_Definitions": list( 2687*da0073e9SAndroid Build Coastguard Worker mapMaybe( 2688*da0073e9SAndroid Build Coastguard Worker gen_composite_functional_kernel, 2689*da0073e9SAndroid Build Coastguard Worker structured_native_functions, 2690*da0073e9SAndroid Build Coastguard Worker ) 2691*da0073e9SAndroid Build Coastguard Worker ), 2692*da0073e9SAndroid Build Coastguard Worker "GeneratedCompositeOut_Definitions": list( 2693*da0073e9SAndroid Build Coastguard Worker mapMaybe( 2694*da0073e9SAndroid Build Coastguard Worker gen_composite_out_kernel, 2695*da0073e9SAndroid Build Coastguard Worker structured_native_functions, 2696*da0073e9SAndroid Build Coastguard Worker ) 2697*da0073e9SAndroid Build Coastguard Worker ), 2698*da0073e9SAndroid Build Coastguard Worker }, 2699*da0073e9SAndroid Build Coastguard Worker ) 
2700*da0073e9SAndroid Build Coastguard Worker 2701*da0073e9SAndroid Build Coastguard Worker 2702*da0073e9SAndroid Build Coastguard Workerdef gen_declarations_yaml( 2703*da0073e9SAndroid Build Coastguard Worker cpu_fm: FileManager, native_functions: Sequence[NativeFunction] 2704*da0073e9SAndroid Build Coastguard Worker) -> None: 2705*da0073e9SAndroid Build Coastguard Worker cpu_fm.write( 2706*da0073e9SAndroid Build Coastguard Worker "Declarations.yaml", 2707*da0073e9SAndroid Build Coastguard Worker lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]), 2708*da0073e9SAndroid Build Coastguard Worker ) 2709*da0073e9SAndroid Build Coastguard Worker 2710*da0073e9SAndroid Build Coastguard Worker 2711*da0073e9SAndroid Build Coastguard Workerdef get_torchgen_root() -> Path: 2712*da0073e9SAndroid Build Coastguard Worker """ 2713*da0073e9SAndroid Build Coastguard Worker If you're depending on torchgen out-of-tree, you can use the root to figure 2714*da0073e9SAndroid Build Coastguard Worker out the path to native_functions.yaml 2715*da0073e9SAndroid Build Coastguard Worker """ 2716*da0073e9SAndroid Build Coastguard Worker return Path(__file__).parent.resolve() 2717*da0073e9SAndroid Build Coastguard Worker 2718*da0073e9SAndroid Build Coastguard Worker 2719*da0073e9SAndroid Build Coastguard Workerdef main() -> None: 2720*da0073e9SAndroid Build Coastguard Worker parser = argparse.ArgumentParser(description="Generate ATen source files") 2721*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2722*da0073e9SAndroid Build Coastguard Worker "-s", 2723*da0073e9SAndroid Build Coastguard Worker "--source-path", 2724*da0073e9SAndroid Build Coastguard Worker help="path to source directory for ATen", 2725*da0073e9SAndroid Build Coastguard Worker default="aten/src/ATen", 2726*da0073e9SAndroid Build Coastguard Worker ) 2727*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2728*da0073e9SAndroid Build Coastguard Worker "-o", 2729*da0073e9SAndroid 
Build Coastguard Worker "--output-dependencies", 2730*da0073e9SAndroid Build Coastguard Worker help="output a list of dependencies into the given file and exit", 2731*da0073e9SAndroid Build Coastguard Worker ) 2732*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2733*da0073e9SAndroid Build Coastguard Worker "--dry-run", 2734*da0073e9SAndroid Build Coastguard Worker action="store_true", 2735*da0073e9SAndroid Build Coastguard Worker help="run without writing any files (still updates outputs)", 2736*da0073e9SAndroid Build Coastguard Worker ) 2737*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2738*da0073e9SAndroid Build Coastguard Worker "--per-operator-headers", 2739*da0073e9SAndroid Build Coastguard Worker action="store_true", 2740*da0073e9SAndroid Build Coastguard Worker help="generate separate headers per operator in ATen/ops", 2741*da0073e9SAndroid Build Coastguard Worker ) 2742*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2743*da0073e9SAndroid Build Coastguard Worker "-d", 2744*da0073e9SAndroid Build Coastguard Worker "--install-dir", 2745*da0073e9SAndroid Build Coastguard Worker "--install_dir", 2746*da0073e9SAndroid Build Coastguard Worker help="output directory", 2747*da0073e9SAndroid Build Coastguard Worker default="build/aten/src/ATen", 2748*da0073e9SAndroid Build Coastguard Worker ) 2749*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2750*da0073e9SAndroid Build Coastguard Worker "--aoti-install-dir", 2751*da0073e9SAndroid Build Coastguard Worker "--aoti_install_dir", 2752*da0073e9SAndroid Build Coastguard Worker help="output directory for AOTInductor shim", 2753*da0073e9SAndroid Build Coastguard Worker default="torch/csrc/inductor/aoti_torch/generated", 2754*da0073e9SAndroid Build Coastguard Worker ) 2755*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2756*da0073e9SAndroid Build Coastguard Worker "--rocm", 2757*da0073e9SAndroid Build Coastguard Worker action="store_true", 
2758*da0073e9SAndroid Build Coastguard Worker help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly", 2759*da0073e9SAndroid Build Coastguard Worker ) 2760*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2761*da0073e9SAndroid Build Coastguard Worker "--mps", 2762*da0073e9SAndroid Build Coastguard Worker action="store_true", 2763*da0073e9SAndroid Build Coastguard Worker help="Generate MPS registration code when set", 2764*da0073e9SAndroid Build Coastguard Worker ) 2765*da0073e9SAndroid Build Coastguard Worker # TODO: --op-registration-whitelist will be removed when all call-sites 2766*da0073e9SAndroid Build Coastguard Worker # for gen.py are moved over to using the operator YAML file for mobile 2767*da0073e9SAndroid Build Coastguard Worker # custom build. 2768*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2769*da0073e9SAndroid Build Coastguard Worker "--op-registration-whitelist", 2770*da0073e9SAndroid Build Coastguard Worker "--op_registration_whitelist", 2771*da0073e9SAndroid Build Coastguard Worker nargs="*", 2772*da0073e9SAndroid Build Coastguard Worker help="filter op registrations by the whitelist (if set); " 2773*da0073e9SAndroid Build Coastguard Worker "each item is `namespace`::`operator name` without overload name; " 2774*da0073e9SAndroid Build Coastguard Worker "e.g.: aten::empty aten::conv2d ...", 2775*da0073e9SAndroid Build Coastguard Worker ) 2776*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2777*da0073e9SAndroid Build Coastguard Worker "--op-selection-yaml-path", 2778*da0073e9SAndroid Build Coastguard Worker "--op_selection_yaml_path", 2779*da0073e9SAndroid Build Coastguard Worker help="Provide a path to the operator selection (for custom build) YAML " 2780*da0073e9SAndroid Build Coastguard Worker "that contains the information about the set of selected operators " 2781*da0073e9SAndroid Build Coastguard Worker "and their categories (training, ...). 
Each operator is either a " 2782*da0073e9SAndroid Build Coastguard Worker "full operator name with overload or just a bare operator name. " 2783*da0073e9SAndroid Build Coastguard Worker "The operator names also contain the namespace prefix (e.g. aten::)", 2784*da0073e9SAndroid Build Coastguard Worker ) 2785*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2786*da0073e9SAndroid Build Coastguard Worker "--backend-whitelist", 2787*da0073e9SAndroid Build Coastguard Worker "--backend_whitelist", 2788*da0073e9SAndroid Build Coastguard Worker nargs="*", 2789*da0073e9SAndroid Build Coastguard Worker help="filter dispatch backend by the whitelist (if set), " 2790*da0073e9SAndroid Build Coastguard Worker "e.g.: CPU CUDA QuantizedCPU ...", 2791*da0073e9SAndroid Build Coastguard Worker ) 2792*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2793*da0073e9SAndroid Build Coastguard Worker "--static-dispatch-backend", 2794*da0073e9SAndroid Build Coastguard Worker "--static_dispatch_backend", 2795*da0073e9SAndroid Build Coastguard Worker nargs="*", 2796*da0073e9SAndroid Build Coastguard Worker help="generate static dispatch code for the specific backend (if set)", 2797*da0073e9SAndroid Build Coastguard Worker ) 2798*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2799*da0073e9SAndroid Build Coastguard Worker "--skip-dispatcher-op-registration", 2800*da0073e9SAndroid Build Coastguard Worker "--skip_dispatcher_op_registration", 2801*da0073e9SAndroid Build Coastguard Worker action="store_true", 2802*da0073e9SAndroid Build Coastguard Worker help="Avoid registering operators into the dispatcher.", 2803*da0073e9SAndroid Build Coastguard Worker ) 2804*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2805*da0073e9SAndroid Build Coastguard Worker "--force-schema-registration", 2806*da0073e9SAndroid Build Coastguard Worker "--force_schema_registration", 2807*da0073e9SAndroid Build Coastguard Worker action="store_true", 
2808*da0073e9SAndroid Build Coastguard Worker help="force it to generate schema-only registrations for all ops, including" 2809*da0073e9SAndroid Build Coastguard Worker "those that are not listed on --op-registration-whitelist", 2810*da0073e9SAndroid Build Coastguard Worker ) 2811*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2812*da0073e9SAndroid Build Coastguard Worker "--generate", 2813*da0073e9SAndroid Build Coastguard Worker type=str, 2814*da0073e9SAndroid Build Coastguard Worker nargs="*", 2815*da0073e9SAndroid Build Coastguard Worker choices=["headers", "sources", "declarations_yaml"], 2816*da0073e9SAndroid Build Coastguard Worker default=["headers", "sources", "declarations_yaml"], 2817*da0073e9SAndroid Build Coastguard Worker help="Generate only a subset of files", 2818*da0073e9SAndroid Build Coastguard Worker ) 2819*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 2820*da0073e9SAndroid Build Coastguard Worker "--update-aoti-c-shim", 2821*da0073e9SAndroid Build Coastguard Worker action="store_true", 2822*da0073e9SAndroid Build Coastguard Worker help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. 
" 2823*da0073e9SAndroid Build Coastguard Worker "WARNING: Do not use this unless you are sure what you are doing!!!", 2824*da0073e9SAndroid Build Coastguard Worker ) 2825*da0073e9SAndroid Build Coastguard Worker 2826*da0073e9SAndroid Build Coastguard Worker options = parser.parse_args() 2827*da0073e9SAndroid Build Coastguard Worker 2828*da0073e9SAndroid Build Coastguard Worker selector = get_custom_build_selector( 2829*da0073e9SAndroid Build Coastguard Worker options.op_registration_whitelist, 2830*da0073e9SAndroid Build Coastguard Worker options.op_selection_yaml_path, 2831*da0073e9SAndroid Build Coastguard Worker ) 2832*da0073e9SAndroid Build Coastguard Worker 2833*da0073e9SAndroid Build Coastguard Worker native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") 2834*da0073e9SAndroid Build Coastguard Worker tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") 2835*da0073e9SAndroid Build Coastguard Worker 2836*da0073e9SAndroid Build Coastguard Worker from torchgen.model import dispatch_keys 2837*da0073e9SAndroid Build Coastguard Worker 2838*da0073e9SAndroid Build Coastguard Worker # TODO: stop generating CUDA kernels for non-CUDA builds 2839*da0073e9SAndroid Build Coastguard Worker ignore_keys = set() 2840*da0073e9SAndroid Build Coastguard Worker if not options.mps: 2841*da0073e9SAndroid Build Coastguard Worker ignore_keys.add(DispatchKey.MPS) 2842*da0073e9SAndroid Build Coastguard Worker 2843*da0073e9SAndroid Build Coastguard Worker if DispatchKey.MPS in dispatch_keys: 2844*da0073e9SAndroid Build Coastguard Worker del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] 2845*da0073e9SAndroid Build Coastguard Worker 2846*da0073e9SAndroid Build Coastguard Worker parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) 2847*da0073e9SAndroid Build Coastguard Worker valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] 2848*da0073e9SAndroid Build Coastguard Worker native_functions, backend_indices 
= ( 2849*da0073e9SAndroid Build Coastguard Worker parsed_yaml.native_functions, 2850*da0073e9SAndroid Build Coastguard Worker parsed_yaml.backend_indices, 2851*da0073e9SAndroid Build Coastguard Worker ) 2852*da0073e9SAndroid Build Coastguard Worker 2853*da0073e9SAndroid Build Coastguard Worker grouped_native_functions = get_grouped_native_functions(native_functions) 2854*da0073e9SAndroid Build Coastguard Worker 2855*da0073e9SAndroid Build Coastguard Worker structured_native_functions = [ 2856*da0073e9SAndroid Build Coastguard Worker g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup) 2857*da0073e9SAndroid Build Coastguard Worker ] 2858*da0073e9SAndroid Build Coastguard Worker native_functions_with_view_groups = get_grouped_by_view_native_functions( 2859*da0073e9SAndroid Build Coastguard Worker native_functions 2860*da0073e9SAndroid Build Coastguard Worker ) 2861*da0073e9SAndroid Build Coastguard Worker view_groups = [ 2862*da0073e9SAndroid Build Coastguard Worker g 2863*da0073e9SAndroid Build Coastguard Worker for g in native_functions_with_view_groups 2864*da0073e9SAndroid Build Coastguard Worker if isinstance(g, NativeFunctionsViewGroup) 2865*da0073e9SAndroid Build Coastguard Worker ] 2866*da0073e9SAndroid Build Coastguard Worker 2867*da0073e9SAndroid Build Coastguard Worker # NB: It is mandatory to NOT use os.path.join here, as the install directory 2868*da0073e9SAndroid Build Coastguard Worker # will eventually be ingested by cmake, which does not respect Windows style 2869*da0073e9SAndroid Build Coastguard Worker # path slashes. 
If you switch this to use os.path.join, you'll get an error 2870*da0073e9SAndroid Build Coastguard Worker # like: 2871*da0073e9SAndroid Build Coastguard Worker # 2872*da0073e9SAndroid Build Coastguard Worker # Syntax error in cmake code when parsing string 2873*da0073e9SAndroid Build Coastguard Worker # 2874*da0073e9SAndroid Build Coastguard Worker # C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h 2875*da0073e9SAndroid Build Coastguard Worker # 2876*da0073e9SAndroid Build Coastguard Worker # Invalid character escape '\c'. 2877*da0073e9SAndroid Build Coastguard Worker core_install_dir = f"{options.install_dir}/core" 2878*da0073e9SAndroid Build Coastguard Worker Path(core_install_dir).mkdir(parents=True, exist_ok=True) 2879*da0073e9SAndroid Build Coastguard Worker ops_install_dir = f"{options.install_dir}/ops" 2880*da0073e9SAndroid Build Coastguard Worker Path(ops_install_dir).mkdir(parents=True, exist_ok=True) 2881*da0073e9SAndroid Build Coastguard Worker aoti_install_dir = f"{options.aoti_install_dir}" 2882*da0073e9SAndroid Build Coastguard Worker Path(aoti_install_dir).mkdir(parents=True, exist_ok=True) 2883*da0073e9SAndroid Build Coastguard Worker 2884*da0073e9SAndroid Build Coastguard Worker core_fm = make_file_manager(options=options, install_dir=core_install_dir) 2885*da0073e9SAndroid Build Coastguard Worker cpu_fm = make_file_manager(options=options) 2886*da0073e9SAndroid Build Coastguard Worker cpu_vec_fm = make_file_manager(options=options) 2887*da0073e9SAndroid Build Coastguard Worker cuda_fm = make_file_manager(options=options) 2888*da0073e9SAndroid Build Coastguard Worker ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) 2889*da0073e9SAndroid Build Coastguard Worker aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir) 2890*da0073e9SAndroid Build Coastguard Worker 2891*da0073e9SAndroid Build Coastguard Worker # Only a limited set of dispatch 
keys get CPUFunctions.h headers generated 2892*da0073e9SAndroid Build Coastguard Worker # for them; this is the set 2893*da0073e9SAndroid Build Coastguard Worker functions_keys = { 2894*da0073e9SAndroid Build Coastguard Worker DispatchKey.CPU, 2895*da0073e9SAndroid Build Coastguard Worker DispatchKey.CUDA, 2896*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutograd, 2897*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor, 2898*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutograd, 2899*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional, 2900*da0073e9SAndroid Build Coastguard Worker DispatchKey.Meta, 2901*da0073e9SAndroid Build Coastguard Worker } 2902*da0073e9SAndroid Build Coastguard Worker if options.mps: 2903*da0073e9SAndroid Build Coastguard Worker functions_keys.add(DispatchKey.MPS) 2904*da0073e9SAndroid Build Coastguard Worker 2905*da0073e9SAndroid Build Coastguard Worker if options.backend_whitelist: 2906*da0073e9SAndroid Build Coastguard Worker dispatch_keys = [ 2907*da0073e9SAndroid Build Coastguard Worker k 2908*da0073e9SAndroid Build Coastguard Worker for k in dispatch_keys 2909*da0073e9SAndroid Build Coastguard Worker if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist 2910*da0073e9SAndroid Build Coastguard Worker ] 2911*da0073e9SAndroid Build Coastguard Worker 2912*da0073e9SAndroid Build Coastguard Worker static_dispatch_idx: list[BackendIndex] = [] 2913*da0073e9SAndroid Build Coastguard Worker if options.static_dispatch_backend: 2914*da0073e9SAndroid Build Coastguard Worker static_dispatch_idx = [ 2915*da0073e9SAndroid Build Coastguard Worker backend_indices[DispatchKey.parse(key)] 2916*da0073e9SAndroid Build Coastguard Worker for key in options.static_dispatch_backend 2917*da0073e9SAndroid Build Coastguard Worker ] 2918*da0073e9SAndroid Build Coastguard Worker for key in options.static_dispatch_backend: 
2919*da0073e9SAndroid Build Coastguard Worker dp_key = DispatchKey.parse(key) 2920*da0073e9SAndroid Build Coastguard Worker if dp_key not in functions_keys: 2921*da0073e9SAndroid Build Coastguard Worker functions_keys.add(dp_key) 2922*da0073e9SAndroid Build Coastguard Worker 2923*da0073e9SAndroid Build Coastguard Worker if "sources" in options.generate: 2924*da0073e9SAndroid Build Coastguard Worker gen_source_files( 2925*da0073e9SAndroid Build Coastguard Worker native_functions=native_functions, 2926*da0073e9SAndroid Build Coastguard Worker grouped_native_functions=grouped_native_functions, 2927*da0073e9SAndroid Build Coastguard Worker structured_native_functions=structured_native_functions, 2928*da0073e9SAndroid Build Coastguard Worker view_groups=view_groups, 2929*da0073e9SAndroid Build Coastguard Worker selector=selector, 2930*da0073e9SAndroid Build Coastguard Worker static_dispatch_idx=static_dispatch_idx, 2931*da0073e9SAndroid Build Coastguard Worker backend_indices=backend_indices, 2932*da0073e9SAndroid Build Coastguard Worker aoti_fm=aoti_fm, 2933*da0073e9SAndroid Build Coastguard Worker core_fm=core_fm, 2934*da0073e9SAndroid Build Coastguard Worker cpu_fm=cpu_fm, 2935*da0073e9SAndroid Build Coastguard Worker cpu_vec_fm=cpu_vec_fm, 2936*da0073e9SAndroid Build Coastguard Worker cuda_fm=cuda_fm, 2937*da0073e9SAndroid Build Coastguard Worker dispatch_keys=dispatch_keys, 2938*da0073e9SAndroid Build Coastguard Worker functions_keys=functions_keys, 2939*da0073e9SAndroid Build Coastguard Worker rocm=options.rocm, 2940*da0073e9SAndroid Build Coastguard Worker force_schema_registration=options.force_schema_registration, 2941*da0073e9SAndroid Build Coastguard Worker per_operator_headers=options.per_operator_headers, 2942*da0073e9SAndroid Build Coastguard Worker skip_dispatcher_op_registration=options.skip_dispatcher_op_registration, 2943*da0073e9SAndroid Build Coastguard Worker update_aoti_c_shim=options.update_aoti_c_shim, 2944*da0073e9SAndroid Build Coastguard 
Worker ) 2945*da0073e9SAndroid Build Coastguard Worker 2946*da0073e9SAndroid Build Coastguard Worker if "headers" in options.generate: 2947*da0073e9SAndroid Build Coastguard Worker gen_headers( 2948*da0073e9SAndroid Build Coastguard Worker native_functions=native_functions, 2949*da0073e9SAndroid Build Coastguard Worker valid_tags=valid_tags, 2950*da0073e9SAndroid Build Coastguard Worker grouped_native_functions=grouped_native_functions, 2951*da0073e9SAndroid Build Coastguard Worker structured_native_functions=structured_native_functions, 2952*da0073e9SAndroid Build Coastguard Worker static_dispatch_idx=static_dispatch_idx, 2953*da0073e9SAndroid Build Coastguard Worker selector=selector, 2954*da0073e9SAndroid Build Coastguard Worker backend_indices=backend_indices, 2955*da0073e9SAndroid Build Coastguard Worker core_fm=core_fm, 2956*da0073e9SAndroid Build Coastguard Worker cpu_fm=cpu_fm, 2957*da0073e9SAndroid Build Coastguard Worker cuda_fm=cuda_fm, 2958*da0073e9SAndroid Build Coastguard Worker ops_fm=ops_fm, 2959*da0073e9SAndroid Build Coastguard Worker dispatch_keys=dispatch_keys, 2960*da0073e9SAndroid Build Coastguard Worker functions_keys=functions_keys, 2961*da0073e9SAndroid Build Coastguard Worker rocm=options.rocm, 2962*da0073e9SAndroid Build Coastguard Worker per_operator_headers=options.per_operator_headers, 2963*da0073e9SAndroid Build Coastguard Worker ) 2964*da0073e9SAndroid Build Coastguard Worker 2965*da0073e9SAndroid Build Coastguard Worker if "declarations_yaml" in options.generate: 2966*da0073e9SAndroid Build Coastguard Worker gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm) 2967*da0073e9SAndroid Build Coastguard Worker 2968*da0073e9SAndroid Build Coastguard Worker if options.output_dependencies: 2969*da0073e9SAndroid Build Coastguard Worker depfile_path = Path(options.output_dependencies).resolve() 2970*da0073e9SAndroid Build Coastguard Worker depfile_name = depfile_path.name 2971*da0073e9SAndroid Build Coastguard Worker 
depfile_stem = depfile_path.stem 2972*da0073e9SAndroid Build Coastguard Worker 2973*da0073e9SAndroid Build Coastguard Worker for fm, prefix in [ 2974*da0073e9SAndroid Build Coastguard Worker (cpu_fm, ""), 2975*da0073e9SAndroid Build Coastguard Worker (cpu_vec_fm, "cpu_vec_"), 2976*da0073e9SAndroid Build Coastguard Worker (core_fm, "core_"), 2977*da0073e9SAndroid Build Coastguard Worker (cuda_fm, "cuda_"), 2978*da0073e9SAndroid Build Coastguard Worker (ops_fm, "ops_"), 2979*da0073e9SAndroid Build Coastguard Worker ]: 2980*da0073e9SAndroid Build Coastguard Worker varname = prefix + depfile_stem 2981*da0073e9SAndroid Build Coastguard Worker path = depfile_path.parent / (prefix + depfile_name) 2982*da0073e9SAndroid Build Coastguard Worker fm.write_outputs(varname, str(path)) 2983*da0073e9SAndroid Build Coastguard Worker 2984*da0073e9SAndroid Build Coastguard Worker 2985*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2986*da0073e9SAndroid Build Coastguard Worker main() 2987