1from __future__ import annotations 2 3from collections import defaultdict 4from typing import Sequence 5 6import torchgen.api.dispatcher as dispatcher 7from torchgen.api.translate import translate 8from torchgen.api.types import Binding, DispatcherSignature, Expr 9from torchgen.context import with_native_function 10from torchgen.model import ( 11 Annotation, 12 Argument, 13 BackendIndex, 14 BackendMetadata, 15 BaseOperatorName, 16 BaseTy, 17 BaseType, 18 DEFAULT_KERNEL_NAMESPACE, 19 DeviceCheckType, 20 DispatchKey, 21 FunctionSchema, 22 NativeFunction, 23 NativeFunctionsGroup, 24 OperatorName, 25 Return, 26 SchemaKind, 27 Variant, 28) 29from torchgen.utils import concatMap 30 31 32# See Note: [Out ops with functional variants that don't get grouped properly] 33OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ 34 # This has a functional variant, but it's currently marked private. 35 # This function should be marked private as well (*_backward ops aren't exposed to python anyway). 36 "adaptive_avg_pool3d_backward.grad_input", 37 # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly. 38 # Maybe we can kill this operator in favor of convolution_backward? 39 "_slow_conv2d_backward.grad_input", 40] 41 42 43# See Note: [Mutable ops that cannot get an out variant] 44MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ 45 # should be out=? 46 "_cummax_helper", 47 # should be out=? 
48 "_cummin_helper", 49] 50 51# All of these operators don't have any tensor like returns 52FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ 53 "_assert_async", # no return 54 "_assert_async.msg", # no return 55 "_cslt_sparse_mm_search", # returns an int 56 "_assert_scalar", # no return 57 "_dimI", # returns an int 58 "_dimV", # returns an int 59 "_has_same_storage_numel", # returns a boolean 60 "_linalg_check_errors", # no return 61 "_local_scalar_dense", # returns a Scalar 62 "_nested_tensor_from_mask_left_aligned", # returns a boolean 63 "_nnz", # returns an int 64 "_use_cudnn_ctc_loss", # returns a boolean 65 "_use_cudnn_ctc_loss.Tensor", # returns a boolean 66 "_validate_compressed_sparse_indices", # no return 67 "allclose", # returns a boolean 68 "dense_dim", # returns an int 69 "equal", # returns a boolean 70 "is_coalesced", # returns an boolean 71 "is_pinned", # returns a boolean 72 "is_same_size", # returns a boolean 73 "is_set_to", # returns a boolean 74 "q_per_channel_axis", # returns an int 75 "q_scale", # returns a float 76 "q_zero_point", # returns an int 77 "qscheme", # returns a QScheme 78 "record_stream", # no return 79 "sparse_dim", # returns an int 80 "sym_constrain_range", # no return 81 "sym_constrain_range_for_size", # no return 82 "_nested_tensor_storage_offsets", # returns a vector of ints 83 "_chunk_grad_outputs_efficient_attention", # returns a bool 84 "_fused_sdp_choice", # returns an int 85 "_print", # no return 86 "_sink_tokens", # no return 87 "_nested_get_ragged_idx", # returns an int 88] 89 90INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ 91 # polygamma and polygamma.out both exist, but have a 92 # pre-self arg (while polygamma_ does not) 93 # We should either fix this schema so it can be grouped properly, 94 # or allow the codegen to generate new functional/out= NativeFunctions for this op 95 # (which would require changing its overload name to prevent overload ambiguity). 
96 "polygamma_" 97] 98 99 100# Groups "similar" NativeFunctions together 101# example add.Tensor, add_.Tensor, add.out 102# "similar" NativeFunctions are all expected to have an identical `signature()`, 103# But have differing SchemaKinds. 104def pre_group_native_functions( 105 native_functions: Sequence[NativeFunction], 106) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]: 107 pre_grouped_native_functions: dict[ 108 FunctionSchema, dict[SchemaKind, NativeFunction] 109 ] = defaultdict(dict) 110 for f in native_functions: 111 d = pre_grouped_native_functions[f.func.signature()] 112 assert f.func.kind() not in d 113 d[f.func.kind()] = f 114 return pre_grouped_native_functions 115 116 117# Returns the out variant overload name given a base function overload name 118def get_expected_out_variant_overload_name(overload_name: str | None) -> str: 119 return "out" if not overload_name else f"{overload_name}_out" 120 121 122# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant 123# Example before: 124# _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) 125# Example after: 126# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) 127def self_to_out_signature(func: FunctionSchema) -> FunctionSchema: 128 # Generating an out= schema from an inplace schema. 
129 assert func.kind() == SchemaKind.inplace 130 assert func.arguments.self_arg is not None 131 # The new out= schema has: 132 # - a new out argument with the same type as "func" (but with a mutable annotation) 133 # - The returns (if any) now alias the out= argument instead of "func" 134 # - an "out" overload name 135 return FunctionSchema( 136 name=func.name.remove_inplace().with_overload( 137 get_expected_out_variant_overload_name(func.name.overload_name) 138 ), 139 arguments=func.arguments.remove_self_annotation().with_out_args( 140 [ 141 Argument( 142 name="out", 143 type=func.arguments.self_arg.argument.type, 144 default=None, 145 annotation=func.arguments.self_arg.argument.annotation, 146 ) 147 ] 148 ), 149 returns=func.returns, 150 ) 151 152 153# Helper function: given a functional FunctionSchema, generate its corresponding out= variant 154# Example before: 155# _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, 156# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor 157# Example after: 158# _to_copy._out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, 159# Tensor(a!) out) -> Tensor(a!) 160def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema: 161 # Generating an out= schema from a functional schema. 
162 assert func.kind() == SchemaKind.functional 163 164 new_returns, new_out_args = generate_out_args_from_schema(func) 165 # The new out= schema has: 166 # - one or more new out argument(s) with the same type as returns (but with a mutable annotation) 167 # - The returns now alias the out= arguments 168 # - an "_out" overload name 169 return FunctionSchema( 170 name=func.name.with_overload( 171 get_expected_out_variant_overload_name(func.name.overload_name) 172 ), 173 arguments=func.arguments.signature().with_out_args( 174 new_out_args, 175 ), 176 returns=tuple(new_returns), 177 ) 178 179 180# Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations. 181def generate_out_args_from_schema( 182 func: FunctionSchema, 183) -> tuple[list[Return], list[Argument]]: 184 # More of a sanity check - our existing restrictions on schemas should enforce that 185 # mutable schema kinds never return their mutable arguments. 186 assert not any( 187 r.annotation is not None and r.annotation.is_write for r in func.returns 188 ) 189 190 tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()] 191 assert len(tensorlike_rets) > 0 192 193 used_annotations = concatMap( 194 lambda a: [] if a.annotation is None else a.annotation.alias_set, 195 func.arguments.flat_all, 196 ) 197 valid_annotations = [ 198 x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations 199 ] 200 201 all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns) 202 203 new_out_args: list[Argument] = [] 204 # The end result of new_returns is that: 205 # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added. 206 # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any). 
207 new_returns: list[Return] = [] 208 for i, r in enumerate(func.returns): 209 if r.type.is_tensor_like(): 210 new_out = Argument( 211 name="out" if len(func.returns) == 1 else f"out{i}", 212 type=r.type, 213 default=None, 214 annotation=Annotation.parse(f"{valid_annotations[i]}!"), 215 ) 216 new_out_args.append(new_out) 217 if all_rets_are_tensors: 218 # The convention for out= schemas is that they only return their out arguments 219 # if the return is a plain Tensor (or if it's a tuple of plain Tensors) 220 new_ret = Return( 221 name=None, type=new_out.type, annotation=new_out.annotation 222 ) 223 new_returns.append(new_ret) 224 else: 225 new_returns.append(r) 226 return new_returns, new_out_args 227 228 229# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant 230# Example before: 231# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 232# Example after: 233# _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950 234def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema: 235 # Generating an out= schema from a mutable schema. 
236 assert func.kind() == SchemaKind.mutable 237 # The new out= schema has: 238 # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments 239 # (if the argument is a tensor then we also return it for method chaining, 240 # otherwise we return nothing) 241 # - an "out" overload name 242 # 243 # Note that: 244 # (1) This also means that we can *only* generate an out= variant from a mutable schema 245 # if the mutable schema has at least one tensor-like non-aliasing return. 246 # (2) The generated out= variant still has mutable positional arguments, 247 # but if necessary we could probably add another out= variant that also 248 # functionalizes the mutable arguments (a functional_out variant) 249 250 new_returns, new_out_args = generate_out_args_from_schema(func) 251 252 return FunctionSchema( 253 name=func.name.remove_inplace().with_overload( 254 get_expected_out_variant_overload_name(func.name.overload_name) 255 ), 256 arguments=func.arguments.with_out_args(new_out_args), 257 returns=tuple(new_returns), 258 ) 259 260 261# This function, given function of one SchemaKind, as well as a target SchemaKind, 262# generates a new NativeFunction with the same properties, but using the target SchemaKind. 263# We only actually generate functions for either functional or out= SchemaKinds. 264# This function returns a tuple, with: 265# - The generated NativeFunction 266# - a dictionary of `BackendIndex` objects, describing which dispatch keys 267# we will generate kernels for, for the new NativeFunction. 268# Details are in the function, but we only generate composite kernels (in some cases) today. 
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
    """Create a new NativeFunction of kind `k` (functional or out=) derived from `f`.

    Returns the new NativeFunction together with backend metadata that registers a
    CompositeExplicitAutograd kernel for it.

    Raises AssertionError for unsupported (source kind, target kind) combinations.
    """
    # Local import, kept function-scoped as in the original code.
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
        #   See Note [Overload Ambiguity With Functional Variants]
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        func = f.func.signature(keep_return_names=True).with_name(
            OperatorName(
                name=BaseOperatorName(
                    base=f.func.name.name.base,
                    inplace=False,
                    dunder_method=f.func.name.name.dunder_method,
                    # See Note [Overload Ambiguity With Functional Variants]
                    functional_overload=f.func.kind() == SchemaKind.mutable,
                ),
                overload_name=f.func.name.overload_name,
            )
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # They still get a CompositeExplicitAutograd kernel entry (see `backend_metadata` below),
        # whose body is expected to come from gen_composite_out_kernel.
        if f.func.kind() == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.functional:
            func = functional_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable or functional variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
    # disambiguate operator with the same name but different overload name, e.g., `randn.names_out` and
    # `randn.generator_with_names_out`.
    kernel_name = (
        func.name.unambiguous_name()
        if func.kind() == SchemaKind.out
        else cpp.name(func)
    )
    if f.func.has_symint():
        kernel_name += "_symint"
    backend_metadata = {
        DispatchKey.CompositeExplicitAutograd: {
            func.name: BackendMetadata(
                kernel=kernel_name,
                structured=False,
                cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
            )
        }
    }
    # Only a small allow-list of the source function's tags is propagated.
    tags = {"generated"} | set(
        f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
    )

    return (
        NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
            # These generated fn's aren't meant to be user friendly- don't generate methods.
            variants={Variant.function},
            structured=False,
            structured_delegate=None,
            structured_inherits=None,
            precomputed=None,
            autogen=[],
            ufunc_inner_loop={},
            manual_kernel_registration=False,
            manual_cpp_binding=False,
            python_module=None,
            category_override=None,
            device_guard=False,
            device_check=DeviceCheckType.NoCheck,
            loc=f.loc,
            cpp_no_default_args=set(),
            is_abstract=f.is_abstract,
            has_composite_implicit_autograd_kernel=False,
            has_composite_implicit_autograd_nested_tensor_kernel=False,
            has_composite_explicit_autograd_kernel=True,
            has_composite_explicit_autograd_non_functional_kernel=False,
            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
            # which NativeFunction objects did not come directly from native_functions.yaml.
            tags=tags,
            namespace=f.namespace,
        ),
        backend_metadata,
    )


# This function is responsible for adding generated NativeFunctions which don't appear
# explicitly in the codegen.
# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
# (Maybe we should make a friendly API for this)
#
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
    rs: list[NativeFunction],
    indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None:
    # The main code for generating new NativeFunctions
    # First we group the NativeFunctions by schema kind,
    # then we detect which ones are missing and generate them.
    pre_grouped_native_functions = pre_group_native_functions(rs)
    for d in pre_grouped_native_functions.values():
        has_functional = SchemaKind.functional in d
        has_inplace = SchemaKind.inplace in d
        has_mutable = SchemaKind.mutable in d
        has_out = SchemaKind.out in d

        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
        #     a simple functional variant that the functionalization pass can consume.
        # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
        #     while maintaining the constraint that the out= variant is "required".
        # Groups containing none of these four schema kinds have nothing to generate.
        if not (has_mutable or has_inplace or has_out or has_functional):
            continue

        # Don't bother generating function trios for native functions that bypass the dispatcher.
        are_manual = all(f.manual_cpp_binding for f in d.values())
        # Don't bother generating functional + out= variants for view operators
        # set_ is technically an inplace_view, but for now it is treated
        # as a normal inplace op in the codegen
        has_view_ops = any(
            f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
        )
        # Don't generate the other variants for CompositeImplicitAutograd operators.
        # We could probably do this, but the main benefit of generating the function triplets
        # is for transforms that need them, and transforms don't need to act directly
        # on CompositeImplicitAutograd operators (since we let them decompose).
        are_composite_implicit = all(
            f.has_composite_implicit_autograd_kernel for f in d.values()
        )
        if are_manual or has_view_ops or are_composite_implicit:
            continue

        if has_out and len(d) == 1:
            # Note: [Out ops with functional variants that don't get grouped properly]
            # In theory we could validly have an out= operator in native_functions.yaml
            # that has no other variants.
            # But today, all of the operators where that's the case actually do have
            # functional variants, that we are just unable to pair up properly.
            # I think banning this all together is probably safer
            # (you can always add a functional variant yourself if you want to add a new out= operator).
            #
            # We should probably fix the existing cases; this check is to prevent us from adding more over time.
            if (
                str(d[SchemaKind.out].func.name)
                not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
            ):
                raise AssertionError(
                    f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
                )
            continue

        # Some inplace ops that have problematic schemas (that we should fix), which prevent us
        # from generating out= and functional variants
        if (
            has_inplace
            and str(d[SchemaKind.inplace].func.name)
            in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
        ):
            continue

        # Pick the variant to derive the missing ones from, in priority order:
        # inplace > mutable > out= > functional.
        if has_inplace:
            base_fn = d[SchemaKind.inplace]
        elif has_mutable:
            base_fn = d[SchemaKind.mutable]
        elif has_out:
            base_fn = d[SchemaKind.out]
        else:
            base_fn = d[SchemaKind.functional]

        # Note: [Mutable ops that cannot get an out variant]
        # We can only generate an out= variant if either:
        # - the original function has tensor-like returns (since we can convert them to out kwargs)
        # - or it's inplace (since we can convert `self` to an out kwarg)
        # There are only two functions that don't fit this criteria today though,
        # and they both look like they should be fixed to be out= variants,
        # so it feels safer to ban this schema all-together
        base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
            r.type.is_tensor_like() for r in base_fn.func.returns
        )
        # Note: [Loosen the assertion that all functional should have out variant]
        # By design all functional operators should have out variants. The needs_out check
        # is loosening this requirement, changing it to only generate out variant if there's
        # an `autogen` block in the native function, in the long run it should be removed.
        # FIXME: Remove this after figuring out CI job failures related to min, max, mean
        needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
        gets_out_variant = not has_out and base_fn_valid and needs_out
        if not has_out and not base_fn_valid:
            if (
                str(base_fn.func.name)
                not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                and str(base_fn.func.name)
                not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
            ):
                raise AssertionError(
                    f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
This type of operators don't have tensor-like return, making it difficult to generate a proper out= variant. If
out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
                )

        # Generate an out= variant
        if gets_out_variant:
            fn, metadata = generate_function(base_fn, SchemaKind.out)
            d[SchemaKind.out] = fn
            BackendIndex.grow_index(indices, metadata)
            rs.append(fn)

        # Generate a functional variant, but only do it if the operator got an out= variant
        # (Functional variants are only useful if we can group up the variants,
        # which we can only do if they have an out= variant)
        if not has_functional and (has_out or gets_out_variant):
            fn, metadata = generate_function(base_fn, SchemaKind.functional)
            d[SchemaKind.functional] = fn
            BackendIndex.grow_index(indices, metadata)
            rs.append(fn)


def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
    """Build the C++ `return` statement for `rets`, returning the expressions in `names`.

    No returns -> empty string; one return -> `return <name>;`; several returns ->
    the values are wrapped in the dispatcher's C++ return type.
    """
    assert len(rets) == len(names)
    if len(rets) == 0:
        return ""
    elif len(rets) == 1:
        return f"return {names[0]};"
    else:
        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"


# Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
    aliased_rets = func.aliased_return_names()
    non_aliased_names = []
    # A single return is held directly in out_var; several are held in a std::tuple.
    is_out_var_a_tuple = len(func.returns) > 1
    for i, r in enumerate(aliased_rets):
        if r is None:
            non_aliased_names.append(
                f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
            )
    return non_aliased_names


# Generates functional kernels in terms of their inplace/mutable counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
    """Return the C++ definition of a generated functional kernel, or None.

    The kernel clones every argument that the wrapped (non-generated) inplace/mutable
    op writes to, calls that op via `at::_ops::<name>::call`, and returns the op's
    non-aliased returns followed by the cloned arguments.
    Returns None when the group's functional op was not code-generated.
    """
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.functional.tags:
        return None
    # And we always write the kernel for a generated op in terms of a non-generated op.
    if g.inplace is not None and "generated" not in g.inplace.tags:
        target_f = g.inplace
    elif g.mutable is not None and "generated" not in g.mutable.tags:
        target_f = g.mutable
    else:
        # We should be guaranteed to have a valid inplace/mutable variant to call into.
        # See Note: [Mutable Ops Not Using Functionalization]
        raise AssertionError(str(g.functional.func))

    sig = DispatcherSignature(g.functional.func)
    target_sig = DispatcherSignature(target_f.func)

    context: list[Binding | Expr] = []
    clone_mutable_inputs = []
    cloned_return_names = []
    # We can't just directly pass all of the arguments from the functional op into the mutating op.
    # We need to check for which inputs to the mutating operator are mutable,
    # and clone those inputs first.
    for a_curr, a_tgt in zip(
        dispatcher.jit_arguments(g.functional.func),
        dispatcher.jit_arguments(target_f.func),
    ):
        if a_tgt.annotation is not None and a_tgt.annotation.is_write:
            clone_mutable_inputs.append(
                f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
            )
            context.append(
                Expr(
                    expr=f"{a_curr.name}_clone",
                    type=dispatcher.argument_type(a_curr, binds=a_curr.name),
                )
            )
            # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
            cloned_return_names.append(f"{a_curr.name}_clone")
        else:
            context.append(dispatcher.argument(a_curr))
    exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])

    out_name = "output"
    maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
    inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
    ret_str = return_str(
        g.functional.func.returns, inner_return_names + cloned_return_names
    )

    clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
    # The "_symint" suffix must match the kernel name chosen in generate_function.
    return f"""
{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  {clone_mutable_inputs_str}
  {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
  {ret_str}
}}
"""


# Generates out= kernels in terms of their functional counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
    """Return the C++ definition of a generated out= kernel, or None.

    The kernel calls the functional op, resizes and copies each tensor-like result
    into the matching out= argument, and returns the out= op's returns.
    Returns None when the group's out= op was not code-generated.
    """
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.out.tags:
        return None
    # And we always write the kernel for the out= op in terms of the functional.
    # Note that the functional op might have also been generated, but we don't have to
    # worry about cycles, because the generated functional kernels are always implemented
    # in terms of non-generated kernels (see gen_composite_functional_kernel).

    sig = DispatcherSignature(g.out.func)
    target_sig = DispatcherSignature(g.functional.func)

    exprs = ", ".join(
        [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
    )

    out_name = "tmp_output"
    functional_rets = g.functional.func.returns

    def functional_return_expr(ret_idx: int) -> str:
        # A single functional return is held directly in tmp_output; several are a std::tuple.
        if len(functional_rets) == 1:
            return out_name
        return f"std::get<{ret_idx}>({out_name})"

    # Generated out= arguments are created, in order, from the op's tensor-like returns,
    # while its non-tensor-like returns stay as (unaliased) returns of the out= op
    # (see generate_out_args_from_schema). Pair each out= argument / unaliased return
    # with its position in the functional op's returns, so interleaved tensor-like and
    # non-tensor-like returns still select the right tuple element.
    tensorlike_ret_idxs = [
        i for i, r in enumerate(functional_rets) if r.type.is_tensor_like()
    ]
    non_tensorlike_ret_idxs = [
        i for i, r in enumerate(functional_rets) if not r.type.is_tensor_like()
    ]

    out_args = g.out.func.arguments.out
    assert len(out_args) <= len(tensorlike_ret_idxs), (
        f"{g.out.func.name} has more out= arguments than "
        f"{g.functional.func.name} has tensor-like returns"
    )

    copy_outs = []
    for out_arg, ret_idx in zip(out_args, tensorlike_ret_idxs):
        functional_return_name = functional_return_expr(ret_idx)
        copy_outs.append(
            f"""\
  resize_out_helper({out_arg.name}, {functional_return_name});
  copy_arg({out_arg.name}, {functional_return_name});"""
        )

    rets = []
    # For each return arg in the calling (out=) operator,
    # If it corresponds to an aliased input, return the input.
    # Otherwise, return the corresponding output from calling the functional operator.
    unaliased_ret_idxs = iter(non_tensorlike_ret_idxs)
    for ret_name in g.out.func.aliased_return_names():
        if ret_name is not None:
            rets.append(ret_name)
        else:
            ret_idx = next(unaliased_ret_idxs, None)
            assert ret_idx is not None, (
                f"Could not match the unaliased returns of {g.out.func.name} "
                f"to the returns of {g.functional.func.name}"
            )
            rets.append(functional_return_expr(ret_idx))

    copy_outs_str = "\n".join(copy_outs)

    # Kernel name needs to follow the naming convention defined in `generate_function()`
    return f"""
{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
  {copy_outs_str}
  {return_str(g.out.func.returns, rets)}
}}
"""