1from __future__ import annotations 2 3from torchgen.api import dispatcher 4from torchgen.api.types import ( 5 BaseCppType, 6 BaseCType, 7 Binding, 8 boolT, 9 ConstRefCType, 10 CType, 11 longT, 12 NamedCType, 13 tensorT, 14) 15from torchgen.model import ( 16 Argument, 17 BaseTy, 18 BaseType, 19 FunctionSchema, 20 NativeFunction, 21 NativeFunctionsViewGroup, 22) 23 24 25# This file describes the translation of JIT schema to API's used 26# when creating view lambdas that are used by the functionalization pass. 27# There are two types of lambdas: forward lambdas and reverse lambdas. 28# These API's mostly follow the dispatcher API, with a few quirks: 29# - The lambda capture has to convert reference types to value types 30# - While the forward lambda just directly calls into the at::_ops API 31# (following the dispatcher convention), the logic here for the reverse lambda 32# is responsible for generating both the call-site, and the declarations 33# (which are implemented manually in the at::functionalization::impl namespace). 34 35# The lambdas generated for each view op in the functionalization pass are of the form 36# [capture_arguments](outer_arguments) -> returns_type { 37# return name(inner_arguments); 38# } 39 40# Define some specific lambda input arguments. 
# Small factory for the hand-written lambda bindings defined below.
# `binding_name` is used both as the Binding name and as the NamedCType name;
# `schema_name` / `schema_ty` describe the JIT-schema Argument attached to the binding.
def _lambda_binding(
    binding_name: str,
    cpp_type: CType,
    *,
    schema_name: str,
    schema_ty: BaseTy,
) -> Binding:
    return Binding(
        name=binding_name,
        nctype=NamedCType(name=binding_name, type=cpp_type),
        argument=Argument(
            name=schema_name,
            type=BaseType(schema_ty),
            default=None,
            annotation=None,
        ),
        default=None,
    )


# The "base" tensor that the lambdas operate on (const Tensor&).
base_binding = _lambda_binding(
    "base",
    ConstRefCType(BaseCType(tensorT)),
    schema_name="base",
    schema_ty=BaseTy.Tensor,
)
# The mutated view handed to the reverse lambda (const Tensor&).
# NOTE(review): the attached schema argument is named "base", not "mutated_view";
# presumably a placeholder that nothing reads - confirm before changing it.
mutated_view_binding = _lambda_binding(
    "mutated_view",
    ConstRefCType(BaseCType(tensorT)),
    schema_name="base",
    schema_ty=BaseTy.Tensor,
)
# Index selecting one output of a multi-output view op (int64_t).
# NOTE(review): the attached schema argument is a Tensor named "base" even though
# the C++ type is a long; presumably also an unused placeholder - confirm.
mutated_view_idx_binding = _lambda_binding(
    "mutated_view_idx",
    BaseCType(longT),
    schema_name="base",
    schema_ty=BaseTy.Tensor,
)
# Flag captured by the forward lambda (bool).
reapply_views_binding = _lambda_binding(
    "reapply_views",
    BaseCType(boolT),
    schema_name="reapply_views",
    schema_ty=BaseTy.bool,
)

InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
# Mode captured by the reverse lambda and forwarded to the view inverse function.
# NB: the schema type is not actually a bool, but it doesn't matter because it isn't used.
inverse_return_mode_binding = _lambda_binding(
    "inverse_return_mode",
    BaseCType(InverseReturnModeT),
    schema_name="inverse_return_mode",
    schema_ty=BaseTy.bool,
)


# The lambda itself is anonymous; this returns the name of the inner function
# that the lambda body calls.
#   - reverse lambda: the view inverse function for g.view (see reverse_name)
#   - forward lambda: the at::_ops entry point of either the view op
#     (reapply_views truthy) or its view_copy variant (otherwise)
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    # reapply_views only matters for the forward lambda (the reverse function
    # always receives the runtime flag), so it may only be omitted in the reverse case.
    assert reapply_views is not None or is_reverse
    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # The forward lambda calls straight into the at::_ops API,
    # so the fully namespaced name is always required.
    assert include_namespace
    assert g.view_copy is not None
    target = g.view if reapply_views else g.view_copy
    return f"at::_ops::{target.func.name.unambiguous_name()}::call"


# Name of the view inverse function for `f`.
# The reverse functions take the "reapply_views" flag at runtime and cover both the
# copy and non-copy variants (avoiding twice as many hand-written inverse functions).
# Codegen emits both call-sites (fully namespaced) and declarations (bare name).
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    qualifier = (
        "at::functionalization::FunctionalInverses::" if include_namespace else ""
    )
    return f"{qualifier}{f.func.name.unambiguous_name()}_inverse"


# Bindings captured by the lambda: a leading mode/flag binding followed by every
# schema argument except `self`.
# Captures must not hold C++ reference types (that would dangle), so reference types
# (IntArrayRef) are converted to value types (vector<int64_t>).
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    schema_args = func.arguments.flat_all
    # The first flattened argument is expected to be the `self` tensor.
    assert schema_args[0].type == BaseType(BaseTy.Tensor)
    leading = inverse_return_mode_binding if is_reverse else reapply_views_binding
    return [
        leading,
        *(
            dispatcher.argument(arg, remove_non_owning_ref_types=True)
            for arg in schema_args[1:]
        ),
    ]


# Return type of the lambda.
# View ops must return at least one value and only tensor-like values, but the lambda
# always returns a single tensor: multi-tensor outputs are tracked one tensor at a time.
def returns_type(func: FunctionSchema) -> CType:
    assert len(func.returns) > 0
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    return BaseCType(tensorT)


# Parameters of the lambda itself: (base[, mutated_view], mutated_view_idx),
# where mutated_view is only present for the reverse lambda.
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    bindings = [base_binding]
    if is_reverse:
        bindings.append(mutated_view_binding)
    bindings.append(mutated_view_idx_binding)
    return bindings


# View ops that return multiple tensors (like `split`) get a separate lambda per output,
# and replaying such an op requires indexing into its result.
# Returns the index binding when the schema has several returns or a single list-like
# return, and None otherwise.
def inner_call_index(func: FunctionSchema) -> Binding | None:
    rets = func.returns
    has_many_returns = len(rets) > 1
    if has_many_returns or (len(rets) == 1 and rets[0].type.is_list_like()):
        return mutated_view_idx_binding
    return None


# Arguments passed to the inner call made by the lambda body.
# The forward lambda calls the at::_ops API and the reverse lambda calls the view
# inverse API; both follow the dispatcher API for the non-self arguments.
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    schema_args = func.arguments.flat_all
    assert schema_args[0].type == BaseType(BaseTy.Tensor)
    trailing = [dispatcher.argument(arg) for arg in schema_args[1:]]

    if is_reverse:
        # The reverse lambda replaces `self` with "base" and additionally passes
        # "mutated_view" and the inverse return mode.
        # Calling convention: for view ops with multiple tensor outputs, the
        # view_inverse function takes an extra index argument right after the mode.
        leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
        index_binding = inner_call_index(func)
        if index_binding is not None:
            leading.append(index_binding)
        return leading + trailing

    # The forward lambda swaps out the original tensor argument for the lambda arg "base".
    return [base_binding] + trailing