# Parses derivatives.yaml into autograd functions
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `torchgen.api.autograd` for the data models.

from __future__ import annotations

import re
from collections import defaultdict
from typing import Any, Counter, Dict, Sequence, Set, Tuple

import yaml

from torchgen.api import cpp
from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    ForwardDerivative,
    SavedAttribute,
)
from torchgen.api.types import (
    BaseCType,
    Binding,
    boolT,
    CppSignatureGroup,
    layoutT,
    longT,
    NamedCType,
    OptionalCType,
    scalarTypeT,
    SpecialArgName,
    stringT,
    symIntArrayRefT,
    SymIntT,
    tensorGeometryT,
    tensorOptionsT,
    typeAndSizeT,
    VectorCType,
)
from torchgen.context import with_native_function
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
from torchgen.model import (
    AUTOGRAD_KEYS,
    FunctionSchema,
    NativeFunction,
    NativeFunctionsViewGroup,
    OperatorName,
    SchemaKind,
    Type,
    Variant,
)
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader


DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}

_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)


# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives(
    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    view_groups: list[NativeFunctionsViewGroup],
) -> None:
    # Get the map from each view op's name to its corresponding view group
    view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
        g.view.func.name: g for g in view_groups
    }

    view_infos = {}

    for info_dispatch_dict in infos.values():
        # maybe_view_group only needs to be calculated once per info_dispatch_dict
        maybe_view_group = None
        view_copy_differentiability_infos = {}
        for dispatch_key, info in info_dispatch_dict.items():
            maybe_view_group = view_name_to_group.get(info.func.func.name, None)
            if maybe_view_group is not None and maybe_view_group.view_copy is not None:
                view_copy_info = info.create_view_copy_from_view_derivative(
                    maybe_view_group
                )
                if view_copy_info is not None:
                    fn_schema = view_copy_info.func.func
                    view_copy_differentiability_infos[dispatch_key] = view_copy_info
            else:
                break
        # prefer manually-defined derivatives if any
        if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
            assert fn_schema is not None
            view_infos[fn_schema] = view_copy_differentiability_infos

    infos.update(view_infos)


def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> DerivativeRet:
    # Do some caching as this is a deterministic function
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
        with open(derivatives_yaml_path) as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
        # From the parsed native functions, separate out the (generated) view_copy functions,
        # so we can generate derivatives for them separately.
        native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
        native_functions = concatMap(
            lambda g: [g]
            if isinstance(g, NativeFunction)
            else list(g.functions(include_copy=True)),
            native_functions_with_view_groups,
        )
        view_groups = [
            g
            for g in native_functions_with_view_groups
            if isinstance(g, NativeFunctionsViewGroup)
        ]

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability annotations, default values, etc.
        # A signature is the canonical schema for a group of functions (in-place/out/functional variants)
        # that are semantically related.
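        # Illustration only (a rough sketch; the exact normalization is whatever
        # FunctionSchema.signature() implements): two schemas such as
        #   add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
        #   add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
        # are distinct keys in functions_by_schema, but are expected to land in the
        # same functions_by_signature bucket, because the signature drops details
        # such as the overload name, mutability annotations and default values.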
        functions_by_signature: dict[
            FunctionSchema, list[NativeFunction]
        ] = defaultdict(list)
        functions_by_schema: dict[str, NativeFunction] = {}
        for function in native_functions:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()

        # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
        # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
        # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
        infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
        used_dispatch_keys: set[str] = set()
        for defn_dict in definitions:
            # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
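            # Illustrative sketch of that normalization (the op `foo` is hypothetical):
            # a legacy entry parsed as
            #   {"name": "foo(Tensor self) -> Tensor", "self": "grad"}
            # is rewritten below into the dispatch-keyed shape
            #   {"name": "foo(Tensor self) -> Tensor",
            #    "dispatch": {"Default": {"self": "grad"}}}
            # with any truthy "output_differentiability" value kept at the top level.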
            if "dispatch" not in defn_dict:
                specification = defn_dict.pop("name")
                output_differentiability = defn_dict.pop(
                    "output_differentiability", None
                )
                defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
                if output_differentiability:
                    defn_dict["output_differentiability"] = output_differentiability
            name, per_dispatch_diffinfos = create_differentiability_info(
                defn_dict,
                functions_by_signature,
                functions_by_schema,
                op_counter,
                used_dispatch_keys,
            )
            infos[name] = per_dispatch_diffinfos

        add_view_copy_derivatives(infos, view_groups)

        # Cache both the loaded infos and the set of all dispatch keys/aliases
        # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
        # VariableType.cpp, where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used.
        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]


# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
@with_native_function
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    sigs = CppSignatureGroup.from_native_function(f, method=False)
    if sigs.symint_signature is not None:
        return sigs.symint_signature.arguments()
    else:
        return sigs.signature.arguments()


def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    original_formula = formula
    arguments: list[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
    return_types = tuple(
        cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
    )

    named_returns = [
        NamedCType(name, type) for name, type in zip(return_names, return_types)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        name
        for name in available_named_gradients
        if re.search(IDENT_REGEX.format(name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )


def create_forward_derivative(
    f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative:
    var_names = names
    var_types: tuple[Type, ...] | None = None
    for r in f.func.returns:
        if r.name in var_names:
            if var_types is None:
                var_types = ()
            var_types = var_types + (r.type,)

    # Handle default return names
    if var_types is None:
        if var_names == ("result",):
            assert len(f.func.returns) == 1
            var_types = (f.func.returns[0].type,)
        else:
            for var_name in var_names:
                res = re.findall(r"^result(\d+)$", var_name)
                if len(res) == 1:
                    if var_types is None:
                        var_types = ()
                    arg_idx = int(res[0])
                    var_types = var_types + (f.func.returns[arg_idx].type,)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=var_names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False,
    )


def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: list[str],
    derivatives: list[Derivative],
    forward_derivatives: list[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> list[ForwardDerivative]:
    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
        is_foreach = f.func.name.name.base.startswith("_foreach_")
        required_inputs = set()
        for arg in args_with_derivatives:
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs.add(arg_name)

        return tuple(required_inputs)

    updated_derivatives: list[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            assert (
                f.func.kind() != SchemaKind.inplace
            ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
            if (
                (not len(args_with_derivatives) == 1)
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if not len(derivatives) == 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "defines the gradient formula for its argument which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to re-use the backward formula and apply three transformations:
            # 1) replace all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) replace all usage of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable input and we necessarily need its tangent, we can
            # simply require the tangents of all inputs.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # for some matrix A, and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again, replacing any occurrence of the differentiable
            # input "foo" by its tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear with respect to
            # the vector where all the differentiable inputs are stacked.

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            assert len(diff_arg_names) > 0

            # Do replacement of input variables
            new_args = []
            for arg_name in all_arg_names:
                if arg_name in diff_arg_names:
                    arg_name = arg_name + "_t"
                new_args.append(arg_name)

            # TODO we are trolling
            if f.func.has_symint():
                defn_name += "_symint"

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = f"at::{defn_name}({', '.join(new_args)})"
            else:
                assert Variant.method in f.variants
                fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})"

            # All of the input tangents are always used so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.

        # In the forward formula, we use the primals instead of the input Tensors.
        # This call inspects the formula to find which inputs' primals are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(
            ForwardDerivative(
                formula=formula,
                var_names=defn.var_names,
                var_types=defn.var_types,
                required_inputs_fw_grad=required_inputs_tangent,
                required_inputs_primal=required_inputs_primal,
                required_original_self_value=False,
                is_reusing_outplace_formula=False,
            )
        )

    return updated_derivatives


def is_forward_derivative_definition(
    all_arg_names: list[str], names: tuple[str, ...]
) -> bool:
    for name in names:
        return name not in all_arg_names
    raise RuntimeError("Expected `names` to be non-empty")


def create_differentiability_info(
    defn_dict: dict[Any, Any],
    functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
    functions_by_schema: dict[str, NativeFunction],
    op_counter: Counter[str],
    used_dispatch_keys: set[str],
) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
    """Processes a single entry `defn` in derivatives.yaml"""

    def canonical_function(
        functions: Sequence[NativeFunction], name: str
    ) -> NativeFunction:
        for f in functions:
            if (
                not f.func.is_functional_fn()
                and not f.func.is_out_fn()
                and name == str(f.func.name.name)
            ):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> tuple[str, ...]:
        """Given "foo, bar", return ("foo", "bar")."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: list[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula)
            )
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'. If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # Set up the derivative information
        derivatives: list[Derivative] = []
        forward_derivatives: list[ForwardDerivative] = []
        non_differentiable_arg_names: list[str] = []
        args_with_derivatives_set: set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [
            r.name for r in f.func.returns
        ]  # only used for the assert below
        # output_differentiability is captured from the enclosing
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(
                        f, formula, names, available_named_gradients
                    )
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} have overlapped non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need to eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` at the call sites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn_dict.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn_dict.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification}, "
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. "
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification}. "
            f"Available signatures:\n{avail}"
        )

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k)
            for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}. Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs. "
            "Please use a different name in native_functions.yaml."
        )

    diffinfo_dict = {}
    for key, defn in defn_dict["dispatch"].items():
        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
            raise RuntimeError(
                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
            )
        if key not in used_dispatch_keys:
            used_dispatch_keys.add(key)

        (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        ) = set_up_derivatives(canonical)

        used_named_gradients: set[str] = set()
        for d in derivatives:
            used_named_gradients |= d.named_gradients

        # only assign an op name if we are actually going to calculate a derivative
        op = None
        if args_with_derivatives:
            op_prefix = _create_op_prefix(defn_name)
            if key != "Default":
                op_prefix = op_prefix + key
            op = f"{op_prefix}{op_counter[op_prefix]}"
            op_counter[op_prefix] += 1

        diffinfo_dict[key] = DifferentiabilityInfo(
            name=defn_name,
            func=canonical,
            op=op,
            derivatives=derivatives,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=dedup_vars(
                [v for d in derivatives for v in d.saved_inputs]
            ),
            all_saved_outputs=dedup_vars(
                [v for d in derivatives for v in d.saved_outputs]
            ),
            available_named_gradients=available_named_gradients,
            used_named_gradients=used_named_gradients,
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=non_differentiable_arg_names,
            output_differentiability=output_differentiability,
            output_differentiability_conditions=output_differentiability_conditions,
        )

    return canonical.func, diffinfo_dict


GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"


def used_gradient_indices(formula: str) -> list[int]:
    """Determine a list of gradient indices (the i in grads[i]) that
    are used by the formula.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]


def saved_variables(
    formula: str,
    nctypes: list[NamedCType],
    var_names: tuple[str, ...],
) -> tuple[str, tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sym_sizes() with self_sym_sizes_opt
        (
            r"{}->sym_sizes\(\)",
            {
                "suffix": "_sym_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? std::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : std::nullopt",
            },
        ),
        # replace self.sym_blocksize() with self_self_sym_blocksize_opt
        # (the suffix below already contains "_self")
        (
            r"{}.sym_blocksize\(\)",
            {
                "suffix": "_self_sym_blocksize_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.sym_size(2) with self_sym_argsize_2
        (
            r"{}.sym_size\((-?\w+)\)",
            {
                "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_numel() with self_sym_numel
        (
            r"{}.sym_numel\(\)",
            {
                "suffix": "_sym_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_sizes_symint(self) with self_args_sizes_symint
        (
            r"to_args_sizes_symint\({}\)",
            {
                "suffix": "_args_sizes_symint",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(SymIntT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_strides() with self_sym_strides
        (
            r"{}.sym_strides\(\)",
            {
                "suffix": "_sym_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: list[SavedAttribute] = []

    if ".sizes()" in formula or "->sizes()" in formula:
        raise RuntimeError(
            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_sizes(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
        r"->size\([-]?\d+\)", formula
    ):
        raise RuntimeError(
            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_size(int), which returns a c10::SymInt. formula={formula}"
        )
    if ".strides()" in formula or "->strides()" in formula:
        raise RuntimeError(
            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_strides(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: re.Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # std::optional<std::string> types stored in Backward nodes must be
        # converted to std::optional<std::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? std::optional<c10::string_view>({name}.value()) : std::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)


def _create_op_prefix(name: str) -> str:
    """Takes a native function name and converts it to an op prefix name.

    Note that the "name" parameter must be the native function name
    without the optional variant suffix, so "add" instead of
    "add.out".

    OP names correspond to classes, hence the change to title case.

    Example::
      >>> _create_op_prefix('add')
      'AddBackward'
    """
    camel_case = "".join([p.title() for p in name.split("_")])
    return (camel_case + "Backward").replace("ForwardBackward", "Backward")


def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    """Drops attributes whose name was already seen, keeping the first occurrence."""
    seen: set[str] = set()
    saved: list[SavedAttribute] = []
    for var in vars:
        name = (
            var.nctype.name.name
            if isinstance(var.nctype.name, SpecialArgName)
            else var.nctype.name
        )
        if name in seen:
            continue
        seen.add(name)
        saved.append(var)
    return saved