from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
from torchgen.gen import pythonify_default
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    Type,
    Variant,
)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           Data Models
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# [Notes] python binding codegen
#
# The Python binding codegen produces code that takes the input list of
# PyObjects, finds the matching ATen C++ function using PythonArgParser,
# converts the PyObjects into C++ types and calls the ATen C++ function:
#
# +--------+  parsing   +------------------------+  binding   +-----------------------+
# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
# +--------+            +------------------------+            +-----------------------+
#
# The following examples demonstrate the data models the Python binding
# codegen needs to deal with and the tasks it needs to accomplish. They
# help explain the purpose of the data types introduced below.
#
#  - Function Schema (source of truth)
#
#      aten::empty.names(int[] size, *, Dimname[]? names,
#                        ScalarType? dtype=None, Layout? layout=None,
#                        Device? device=None, bool? pin_memory=None,
#                        MemoryFormat? memory_format=None) -> Tensor
#
#  - Python Signature
#
#    It's used to generate the input schema string for PythonArgParser.
#    Note: TensorOptions fields are reordered and the additional
#    'requires_grad' field is added:
#
#      empty(IntArrayRef size, *, DimnameList? names,
#            MemoryFormat? memory_format=None, ScalarType dtype=None,
#            Layout layout=torch.strided, Device device=None,
#            bool pin_memory=False, bool requires_grad=False)
#
#  - C++ Signature
#
#    It's used to generate the C++ lambda formals & dispatch call.
#    Note: the scattered TensorOptions fields are packed into 'options'.
#
#      auto dispatch_empty =
#          [](IntArrayRef size, std::optional<DimnameList> names,
#             const TensorOptions & options,
#             std::optional<MemoryFormat> memory_format) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return torch::empty(size, names, options, memory_format);
#      };
#
#  - Binding between Python Arguments and C++ Arguments
#
#    Given a set of Python Arguments in scope, we need to produce the
#    binding expressions that translate the Python API into the C++ API:
#
#         Python Args                  Cpp Args       Binding Exprs
#         ------------------------------------------------------------------------
#         0: size                      size           '_r.intlist(0)'
#         1: names                     names          'names' [special init]
#         2: memory_format -------+
#         3: dtype         -----+-|--> options        'options' [special packing]
#         4: layout            /  |
#         5: device           /   +--> memory_format  '_r.memoryformatOptional(2)'
#         6: pin_memory      /
#         7: requires_grad -+
#
#    So the full dispatch expression would look like:
#
#      dispatch_empty(_r.intlist(0), names, options,
#                     _r.memoryformatOptional(2))
#
#    Where does 'names' come from? It involves special local init:
#
#      auto __names = _r.toDimnameListOptional(1);
#      std::optional<DimnameList> names =
#          __names ? std::make_optional(DimnameList(__names.value()))
#                  : std::nullopt;
#
#    Where does 'options' come from? It involves special local init
#    for TensorOptions. Note that the Python side has the additional
#    'requires_grad' field:
#
#      const auto options = TensorOptions()
#          .dtype(_r.scalartype(3))
#          .device(_r.device(5))
#          .layout(_r.layoutOptional(4))
#          .requires_grad(_r.toBool(7))
#          .pinned_memory(_r.toBool(6));
#
#    In some other cases one Python Argument can map to multiple C++
#    Arguments. For example:
#
#      aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
#        -> (Tensor values, Tensor indices)
#
#         Python Args                  Cpp Args          Binding Exprs
#         ----------------------------------------------------------------------
#                               +----> max               'out[0]'
#                              /-----> max_values        'out[1]'
#         0: input            /        self              '_r.tensor(0)'
#         1: dim             /         dim               '_r.dimname(1)'
#         2: keepdim        /          keepdim           '_r.toBool(2)'
#         3: out      -----+           [local init] out  '_r.tensorlist_n<2>(3)'
#
#    As demonstrated above, the binding can involve reordering,
#    packing, unpacking and special local inits.
#
#
#  Let's look at a concrete example:
#
#      static PythonArgParser parser({
#        "abs(Tensor input, *, Tensor out=None)",
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- Python Schema, represented by PythonSignature and PythonArgument
#
#      }, /*traceable=*/true);
#
#      ParsedArgs<2> parsed_args;
#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
#
#      ...
#
#      if (_r.isNone(1)) {
#          ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
#                            represented by PythonArgParserOutputExpr
#
#        // aten::abs(Tensor self) -> Tensor
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, base version
#
#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#                             ^
#                             +--- dispatch_lambda_args / dispatch_lambda_return_str
#                                  generated from NativeFunction / CppSignature
#                                  (deprecated PythonSignature is special)
#                                  arguments are represented by DispatchLambdaArgument
#
#          pybind11::gil_scoped_release no_gil;
#          return self.abs();
#                 ~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
#                                 generated from NativeFunction / CppSignature
#        };
#        return wrap(dispatch_abs(_r.tensor(0)));
#                                 ~~~~~~~~~~~~
#                                  ^
#                                  +--- dispatch_lambda_exprs
#                                       binding PythonArgParserOutputExpr (python args)
#                                       and DispatchLambdaArgument (c++ args)
#
#      } else {
#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, out-variant
#
#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return at::abs_out(out, self);
#        };
#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
#      }
#
#
# [Notes] python interface codegen
# The python dataclasses below are used to generate both python binding code
# and pyi type hint signatures.
# In theory these two should look very similar, but there are a number of differences
# in how pyi signatures vs. python_arg_parser signatures are generated.
# These differences have been encapsulated in signature_str() vs. signature_str_pyi()
# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
# For example, only pyi signatures include return types.


@dataclass(frozen=True)
class PythonReturns:
    returns: tuple[Return, ...]


@dataclass(frozen=True)
class PythonArgument:
    name: str
    type: Type
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend()))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            ^
    #                            +--- default_init str
    default_init: str | None

    # Compute argument formal for python argument parsing.
    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
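    #
    # Illustrative examples (a sketch derived by reading argument_str below;
    # the argument values are hypothetical, not output captured from a codegen run):
    #
    #   PythonArgument(name="self", type=BaseType(BaseTy.Tensor),
    #                  default=None, default_init=None).argument_str()
    #     -> "Tensor input"         ('self' is renamed outside method bindings)
    #   The same argument with argument_str(method=True)
    #     -> "Tensor self"
    #   PythonArgument(name="weight", type=OptionalType(BaseType(BaseTy.Tensor)),
    #                  default="{}", default_init=None).argument_str()
    #     -> "Tensor? weight=None"  (C++ defaults such as "{}" map to None)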
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        type_str = (
            argument_type_str(self.type, symint=symint)
            .replace("const ", "")
            .replace(" &", "")
        )

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str in ["Tensor", "Number"] and not method:
            name = "input"

        # add default
        if self.default is not None:
            default = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
            }.get(self.default, self.default)
            return f"{type_str} {name}={default}"
        else:
            return f"{type_str} {name}"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        if name == "from":  # from is a Python keyword...
            name += "_"

        # pyi merges the _out and functional variants into the same signature, with an optional out arg
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # pyi deprecated signatures don't get defaults for their out arg
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )

        # add default
        if self.default is not None and not treat_as_no_default:
            if (
                isinstance(self.type, ListType)
                and self.type.elem == BaseType(BaseTy.int)
                and self.default.startswith("{")
                and self.default.endswith("}")
            ):
                default = (
                    "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
                )
            else:
                default = {
                    "nullptr": "None",
                    "::std::nullopt": "None",
                    "std::nullopt": "None",
                    "{}": "None",
                    "c10::MemoryFormat::Contiguous": "contiguous_format",
                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
                }.get(self.default, self.default)
            return f"{name}: {type_str} = {default}"
        else:
            return f"{name}: {type_str}"


@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    # In Python signature multiple output fields are packed into one 'out' argument.
    # When binding to C++, it's first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # then bound to scattered C++ output arguments as 'out[0]', 'out[1]', etc.
    # TODO: maybe we don't need to keep scattered out fields for python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        if not outputs:
            return None

        size = len(outputs)
        if size == 1:
            return PythonOutArgument(
                name=outputs[0].name,
                type=outputs[0].type,
                default="None",
                default_init=None,
                outputs=outputs,
            )
        elif size > 1:
            if any(not a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            return PythonOutArgument(
                name="out",
                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                type=ListType(BaseType(BaseTy.Tensor), size),
                default="None",
                default_init=None,
                outputs=outputs,
            )
        raise AssertionError(r"Unexpected PythonOutArgument size")


@dataclass(frozen=True)
class PythonSignature:
    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc.).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are scattered keyword arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take a TensorOptions object (e.g.
    # for the out variant), in which case they will be used as scattered fields
    # without being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        return len(self.arguments())

    def output_idx(self) -> int:
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
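    #
    # Illustrative example (a sketch derived by reading signature_str below;
    # the PythonSignature contents are hypothetical): for a function (not
    # method) signature named "abs" whose input_args hold a single Tensor
    # 'self', with empty input_kwargs and tensor_options_args, and whose
    # output_args is a single Tensor 'out' defaulting to None:
    #
    #   signature_str()                  -> "abs(Tensor input, *, Tensor out=None)"
    #   signature_str(skip_outputs=True) -> "abs(Tensor input)"
    #
    # The "*" separator is inserted only when more formals follow the
    # positional ones, and PythonSignatureDeprecated appends "|deprecated".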
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f'{self.name}({", ".join(schema_formals)})'

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # only pyi uses vararg signatures
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
        num_args = self.arguments_count()
        num_positionalargs = len(self.input_args)

        have_vararg_version = False
        if num_args > 0:
            vararg_type = args[0].type
            if (
                isinstance(vararg_type, ListType)
                and str(vararg_type.elem) in ["int", "SymInt"]
                and num_positionalargs == 1
            ):
                have_vararg_version = True

        if not have_vararg_version:
            return None

        # Below are the major changes in vararg vs. regular pyi signatures
        # vararg signatures also omit the bare '*' keyword-only separator
        assert isinstance(vararg_type, ListType)
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'


# The deprecated python signature involves some special logic, so create a
# dedicated data model to store these extra properties.
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    # Schema for the deprecated function
    deprecated_schema: FunctionSchema

    # The deprecated signature might miss some arguments that the corresponding
    # C++ signature expects. We need to store the constant default values to pass in.
    # For example:
    #   [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        return (
            PythonSignature.signature_str(
                self, skip_outputs=skip_outputs, symint=symint
            )
            + "|deprecated"
        )

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
        ]
        positional_argc = len(self.input_args)
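        # Formals at index >= positional_argc are keyword-only, so a bare "*" is
        # inserted in front of them. Illustrative example with hypothetical
        # formals: positional_argc == 2 and ["a", "b", "c"] becomes
        # ["a", "b", "*", "c"]. If every formal is positional, no "*" is added,
        # because Python rejects a bare "*" with no parameters after it.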
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        returns_str = returns_str_pyi(self)
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # the codegen doesn't include vararg variants for deprecated signatures
        return None


# This struct is used to hold the PythonSignature and its corresponding
# NativeFunction BEFORE grouping base and out-variant functions.
# Why not store NativeFunction in PythonSignature or construct PythonSignature
# from NativeFunction? Because they are not 1-1 mapped.
# One native function could have both deprecated and non-deprecated python
# signatures - NativeFunction doesn't contain information to construct the
# deprecated python signature.
# One python signature is used to handle both the base and the out-variant
# function - see 'PythonSignatureGroup'.
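# For example, aten::addmm has a regular python signature and also the
# deprecated addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) one
# described in PythonSignatureDeprecated, so it appears in two pairs that
# share the same NativeFunction.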
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    signature: PythonSignature
    function: NativeFunction


# We merge pairs of functions with signatures that are equivalent modulo
# output arguments, and use a single entry in the python_arg_parser sig
# list for both (output arguments become optional).
@dataclass(frozen=True)
class PythonSignatureGroup:
    # The signature used for Python argument parsing. The outplace signature
    # is preferred if it exists, because it can be used to parse inputs for both
    # the out-place variant and the base version (with output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out)
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        if out is None:
            return PythonSignatureGroup(
                signature=functional.signature,
                base=functional.function,
                outplace=None,
            )

        # prefer the signature with optional out=... arguments because it's the
        # superset that can be used to parse input for both base and outplace.
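        # For example (illustrative), one parser entry of the form
        #   add(Tensor input, Tensor other, *, Scalar alpha=1, Tensor out=None)
        # serves both torch.add(a, b), dispatched to the base function, and
        # torch.add(a, b, out=c), dispatched to the out variant once the
        # generated code sees that `out` is not None.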
        signature_kwargs = out.signature.__dict__.copy()

        # Out overloads in C++ don't have TensorOptions arguments,
        # so take these from the functional variant
        signature_kwargs[
            "tensor_options_args"
        ] = functional.signature.tensor_options_args

        return PythonSignatureGroup(
            signature=type(out.signature)(**signature_kwargs),
            base=functional.function,
            outplace=out.function,
        )


# C++ function dispatch is wrapped in a lambda function. The lambda function
# has almost the same signature as the C++ function, with only a few small
# differences - see details below.
# This data model is used to represent arguments of the lambda function
# signature.
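# For example (illustrative values only), the lambda wrapping aten::add.Tensor
# takes arguments along the lines of
#   DispatchLambdaArgument(name="self", type_str="const at::Tensor &", is_out_arg=False)
# while an out= overload additionally takes an argument with is_out_arg=True.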
@dataclass(frozen=True)
class DispatchLambdaArgument:
    name: str
    type_str: str
    is_out_arg: bool


# To pass PyObject arguments to the C++ function (via the lambda wrapper),
# we first need to convert the PyObjects into simple C++ objects. This work
# is done by PythonArgParser.
# This data model is used to represent the output of PythonArgParser.
# It has a 1-1 mapping with PythonArgument in PythonSignature.
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    # argument name
    name: str

    # RHS expression to reference PythonArgParser output.
    expr: str

    # In some special cases we need to create a different expr from the index,
    # e.g. '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument it maps to.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        return f"_r.isNone({self.index})"


# To pass PythonArgParser output to the lambda wrapper, we need to bind
# PythonArgParserOutputExpr to DispatchLambdaArgument.
# They are not always 1-1 mapped, e.g. scattered TensorOptions fields
# need to be packed into a TensorOptions object, which is the argument
# that the lambda function wrapper takes.
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    # The exprs that provide the binding for lambda arguments, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # It has a 1-1 mapping with DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local inits, which might introduce new variables that
    # the 'exprs' above reference, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Helper Functions
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    return CppSignatureGroup.from_native_function(f, method=method).signature


def has_tensor_options(f: NativeFunction) -> bool:
    return f.func.arguments.tensor_options is not None


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Python Signature
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# 'simple_type' was introduced by the old codegen and differs slightly from
# the python schema type, e.g. it has no '?' suffix for optional
# Tensor/TensorList and no '[size]' suffix for list types.
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return "Tensor"
        elif t.name == BaseTy.int:
            return "int64_t"
        elif t.name == BaseTy.float:
            return "double"
        elif t.name == BaseTy.str:
            return "c10::string_view"
        elif t.name in [
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.ConstQuantizerPtr,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            # Is it desired to keep '?' for simple_type with new style dispatcher?
            return "Tensor?"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"{elem}?"
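    # List types map to C++ array spellings, e.g.:
    #   int[2]    -> "IntArrayRef[2]" ("IntArrayRef" when simple_type=True)
    #   SymInt[]  -> "SymIntArrayRef" ("IntArrayRef" when symint=False)
    #   bool[3]   -> "::std::array<bool,3>" (the size is always kept)
    #   Tensor?[] -> "const c10::List<::std::optional<Tensor>> &"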
    elif isinstance(t, ListType):
        size = t.size if not simple_type else None
        if str(t.elem) == "bool":
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"
        elif str(t.elem) == "int":
            return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "SymInt":
            if symint:
                return (
                    f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
                )
            else:
                return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "Tensor":
            return f"TensorList[{size}]" if size is not None else "TensorList"
        elif str(t.elem) == "Scalar":
            return f"ScalarList[{size}]" if size is not None else "ScalarList"
        elif str(t.elem) == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            else:
                return "const c10::List<::std::optional<Tensor>> &"
        elif str(t.elem) == "Dimname":
            return f"DimnameList[{size}]" if size is not None else "DimnameList"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{elem}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")


def argument_type_size(t: Type) -> int | None:
    l = t.is_list_like()
    if l is not None and str(l.elem) != "bool":
        return l.size
    else:
        return None


def argument(a: Argument) -> PythonArgument:
    return PythonArgument(
        name=a.name,
        type=a.type,
        # TODO: directly translate a.default to python default
        default=(
            str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
            if a.default is not None
            else None
        ),
        default_init=None,
    )


# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    return signature_from_schema(
        f.func, category_override=f.category_override, method=method, pyi=pyi
    )


def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    args: list[Argument] = []
    args.extend(func.arguments.pre_self_positional)
    # Skip SelfArgument if this is a method.
    if not method and func.arguments.self_arg is not None:
        args.append(func.arguments.self_arg.argument)
    args.extend(func.arguments.post_self_positional)
    args.extend(func.arguments.pre_tensor_options_kwarg_only)
    # Skip TensorOptionsArguments. Python side TensorOptions
    # arguments are created based on different rules - see below.
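    # For example, for aten::empty.names, `names` is a kwarg-only argument
    # before the TensorOptions group and `memory_format` is one after it, while
    # dtype/layout/device/pin_memory are dropped here and re-created below.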
    args.extend(func.arguments.post_tensor_options_kwarg_only)
    args.extend(func.arguments.out)

    input_arg_set = {a.name for a in func.arguments.flat_positional}
    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in func.arguments.out}

    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
    input_kwargs = tuple(
        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
    )
    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to their cpp counterparts, the python arguments have a new
    # property (default_init) and a new argument ('requires_grad'), both of
    # which require special handling.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in func.arguments.flat_non_out
    )
    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
        raise ValueError(
            "argument named requires_grad is reserved, should not explicitly add it in the schema"
        )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
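    # For example, aten::empty.names has no tensor inputs but returns a Tensor,
    # so it is a factory function, while zeros_like qualifies through its
    # "_like" suffix; both get the dtype, layout, device, pin_memory and
    # requires_grad python arguments appended by the branch below.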
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            topt_args = func.arguments.tensor_options
            if topt_args is None:
                return None
            a = getattr(topt_args, name)
            if a.default is None or a.default == "None":
                return None
            return cpp.default_expr(a.default, a.type, symint=False)

        tensor_options_args.append(
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("dtype")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("layout")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=(
                    None
                    if is_like_or_new_function
                    else (
                        topt_default_init("device")
                        or "torch::tensors::get_default_device()"
                    )
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
Coastguard Worker ) 872*da0073e9SAndroid Build Coastguard Worker ) 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker returns = PythonReturns(returns=func.returns) 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker return PythonSignature( 877*da0073e9SAndroid Build Coastguard Worker name=str(func.name.name), 878*da0073e9SAndroid Build Coastguard Worker input_args=input_args, 879*da0073e9SAndroid Build Coastguard Worker input_kwargs=input_kwargs, 880*da0073e9SAndroid Build Coastguard Worker output_args=PythonOutArgument.from_outputs(outputs), 881*da0073e9SAndroid Build Coastguard Worker tensor_options_args=tuple(tensor_options_args), 882*da0073e9SAndroid Build Coastguard Worker returns=returns, 883*da0073e9SAndroid Build Coastguard Worker method=method, 884*da0073e9SAndroid Build Coastguard Worker ) 885*da0073e9SAndroid Build Coastguard Worker 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 888*da0073e9SAndroid Build Coastguard Worker# 889*da0073e9SAndroid Build Coastguard Worker# Python Interface 890*da0073e9SAndroid Build Coastguard Worker# 891*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Workerdef structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]: 895*da0073e9SAndroid Build Coastguard Worker if len(returns) <= 1 or all(r.name is None for r in returns): 896*da0073e9SAndroid Build Coastguard Worker return [] 897*da0073e9SAndroid Build Coastguard Worker else: 898*da0073e9SAndroid Build Coastguard Worker if any(r.name is None for r in returns): 899*da0073e9SAndroid Build Coastguard Worker # When building on Windows, `PyStructSequence_UnnamedField` could not be 
            # resolved by the linker for some reason, which causes a build error:
            #
            #   python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            #   PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in structseq; you must either name all fields,
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")

        return [str(r.name) for r in returns]


def argument_type_str_pyi(t: Type) -> str:
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = "Optional[" + ret + "]"

    return ret


def return_type_str_pyi(t: Type) -> str:
    # Where arguments are open to accepting Union, return types should return
    # concrete types

    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        elif t.name == BaseTy.Dimname:
            return "Optional[str]"
        else:
            return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)


def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        # These types are structseq objects which act like named NamedTuples, but
        # the constructor acts like the constructor of tuple. Using typing.NamedTuple
        # does not allow us to override __init__.
        seq_type = f"Tuple[{', '.join(python_returns)}]"
        structseq_def_lines = [
            f"class {structseq_name}({seq_type}):",
        ]
        for name, typ in zip(field_names, python_returns):
            structseq_def_lines.extend(
                [
                    "    @property",
                    f"    def {name}(self) -> {typ}: ...",
                ]
            )
        structseq_def_lines.extend(
            [
                f"    def __new__(cls, sequence: {seq_type}): ...",
                f"    n_fields: _int = {len(field_names)}",
                f"    n_sequence_fields: _int = {len(field_names)}",
                "    n_unnamed_fields: _int = 0",
                "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
                "",  # add an extra newline
            ]
        )
        structseq_def = "\n".join(structseq_def_lines)
        # Example:
        # structseq_def = (
        #     "class max(Tuple[Tensor, Tensor]):\n"
        #     "    @property\n"
        #     "    def values(self) -> Tensor: ...\n"
        #     "    @property\n"
        #     "    def indices(self) -> Tensor: ...\n"
        #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
        #     "    n_fields: _int = 2\n"
        #     "    n_sequence_fields: _int = 2\n"
        #     "    n_unnamed_fields: _int = 0\n"
        #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing\n"
        # )
        return structseq_name, structseq_def
    return None


def returns_str_pyi(signature: PythonSignature) -> str:
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        return f"torch.return_types.{signature.name}"

    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    if len(python_returns) > 1:
        return "Tuple[" + ", ".join(python_returns) + "]"
    if len(python_returns) == 1:
        return python_returns[0]
    return "None"


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ Function Dispatch
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# This section provides APIs to generate the code that does C++ function
# dispatch. The C++ function call is wrapped by a lambda function.
# For example:
#
#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
#      pybind11::gil_scoped_release no_gil;
#      return at::selu_(self);
#    };
#
# The lambda function's signature follows the C++ signature in common
# cases, e.g.:
#
#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For the out variant, the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because the lambda is called with the PythonArgParser
# output '_r.tensor(3)', which is a stack-allocated temporary and therefore
# must be passed by value. Also see the comments in 'dispatch_lambda_return_str()'.
#
#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# In the multi-output case, the lambda can keep using reference types because
# the PythonArgParser output has already been unpacked into local variables, e.g.:
#
#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
#   //     Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
#
# For a deprecated Python signature, the lambda follows the deprecated Python argument order.
# TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?
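

# The by-value versus by-reference rule described above can be sketched with
# plain strings. This is a simplified, hypothetical helper (it is not part of
# torchgen); it assumes each C++ argument is given as a (name, type_str) pair
# and mirrors only the type rewriting performed in dispatch_lambda_args(),
# leaving out Binding/DispatchLambdaArgument and the deprecated-schema path.
def _sketch_dispatch_lambda_types(
    cpp_args: list[tuple[str, str]],
    out_names: set[str],
    is_method: bool,
) -> list[str]:
    lambda_types = []
    for name, type_str in cpp_args:
        if is_method and name == "self":
            # A method's 'self' can be taken by const reference; mutability is ignored.
            type_str = "const at::Tensor &"
        elif len(out_names) <= 1 or name not in out_names:
            # The lambda is called with a temporary such as _r.tensor(3), which
            # cannot bind to a mutable reference, so take the tensor by value.
            if type_str == "at::Tensor &":
                type_str = "at::Tensor"
        # Otherwise this is an out argument of a multi-output function; it was
        # unpacked into a local variable, so the mutable reference is kept.
        lambda_types.append(type_str)
    return lambda_types


# Usage (hypothetical type strings): aten::add.out has a single out argument,
# so 'out' is rewritten to be taken by value:
#   _sketch_dispatch_lambda_types(
#       [("out", "at::Tensor &"), ("self", "const at::Tensor &")],
#       {"out"},
#       is_method=False,
#   )
#   -> ["at::Tensor", "const at::Tensor &"]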


def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    if isinstance(ps, PythonSignatureDeprecated):
        schema = ps.deprecated_schema
    else:
        schema = f.func

    # Start with cpp arguments - the dispatch lambda signature always includes 'self'
    cpp_args = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_args: set[str] = {a.name for a in schema.arguments.out}

    # Convert from cpp argument to lambda argument
    def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
        type_str = cpp_arg.type
        is_out_arg = cpp_arg.name in out_args
        if ps.method and cpp_arg.name == "self":
            # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
            type_str = "const at::Tensor &"
        else:
            # For other cases we need to prevent dangling refs to temps (unless it's
            # an unpacked scattered output).
            # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
            if ensure_temp_safe:
                type_str = {
                    "at::Tensor &": "at::Tensor",
                }.get(type_str, type_str)
        return DispatchLambdaArgument(
            name=cpp_arg.name,
            type_str=type_str,
            is_out_arg=is_out_arg,
        )

    return tuple(map(dispatch_lambda_arg, cpp_args))


# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    "at::Tensor",
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::vector<at::Tensor>",
    # Needed for flash attention forward/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
    "at::Scalar",
    "bool",
    "int64_t",
    "void*",
    "void",
    "at::QScheme",
    "double",
    "at::IntArrayRef",
    "at::ScalarType",
    "at::Stream",
}


def dispatch_lambda_return_str(f: NativeFunction) -> str:
    # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
    # because the dispatch lambdas take mutable arguments *by value*, not
    # by reference. If you then return a reference to such an argument, you
    # will now have a pointer to a dangling stack entry. Not good.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                            ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                            ^^^^^^^
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to a temporary. Maybe we could assign it to a
Maybe we could assign it to a 1200*da0073e9SAndroid Build Coastguard Worker # variable itself.) 1201*da0073e9SAndroid Build Coastguard Worker returns_without_annotation = tuple( 1202*da0073e9SAndroid Build Coastguard Worker Return(r.name, r.type, None) for r in f.func.returns 1203*da0073e9SAndroid Build Coastguard Worker ) 1204*da0073e9SAndroid Build Coastguard Worker return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type() 1205*da0073e9SAndroid Build Coastguard Worker if return_str not in SUPPORTED_RETURN_TYPES: 1206*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}") 1207*da0073e9SAndroid Build Coastguard Worker return return_str 1208*da0073e9SAndroid Build Coastguard Worker 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Workerdef cpp_dispatch_target(f: NativeFunction) -> str: 1211*da0073e9SAndroid Build Coastguard Worker symint = f.func.has_symint() 1212*da0073e9SAndroid Build Coastguard Worker name = cpp.name(f.func, symint_overload=symint) 1213*da0073e9SAndroid Build Coastguard Worker if Variant.method in f.variants: 1214*da0073e9SAndroid Build Coastguard Worker return f"self.{name}" 1215*da0073e9SAndroid Build Coastguard Worker if Variant.function in f.variants: 1216*da0073e9SAndroid Build Coastguard Worker if has_tensor_options(f) or f.func.name.name.base.endswith("_like"): 1217*da0073e9SAndroid Build Coastguard Worker namespace = "torch" 1218*da0073e9SAndroid Build Coastguard Worker else: 1219*da0073e9SAndroid Build Coastguard Worker namespace = "at" 1220*da0073e9SAndroid Build Coastguard Worker return f"{namespace}::{name}" 1221*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}") 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Workerdef cpp_dispatch_exprs( 1225*da0073e9SAndroid 
Build Coastguard Worker f: NativeFunction, 1226*da0073e9SAndroid Build Coastguard Worker *, 1227*da0073e9SAndroid Build Coastguard Worker python_signature: PythonSignature | None = None, 1228*da0073e9SAndroid Build Coastguard Worker) -> tuple[str, ...]: 1229*da0073e9SAndroid Build Coastguard Worker cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() 1230*da0073e9SAndroid Build Coastguard Worker 1231*da0073e9SAndroid Build Coastguard Worker exprs: tuple[str, ...] = () 1232*da0073e9SAndroid Build Coastguard Worker if not isinstance(python_signature, PythonSignatureDeprecated): 1233*da0073e9SAndroid Build Coastguard Worker # By default the exprs are consistent with the C++ signature. 1234*da0073e9SAndroid Build Coastguard Worker exprs = tuple(a.name for a in cpp_args) 1235*da0073e9SAndroid Build Coastguard Worker else: 1236*da0073e9SAndroid Build Coastguard Worker # For deprecated python signature we may need fill in some constants. 1237*da0073e9SAndroid Build Coastguard Worker exprs = tuple( 1238*da0073e9SAndroid Build Coastguard Worker filter( 1239*da0073e9SAndroid Build Coastguard Worker lambda n: n != "out" or f.func.is_out_fn(), 1240*da0073e9SAndroid Build Coastguard Worker python_signature.deprecated_args_exprs, 1241*da0073e9SAndroid Build Coastguard Worker ) 1242*da0073e9SAndroid Build Coastguard Worker ) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker if Variant.method in f.variants: 1245*da0073e9SAndroid Build Coastguard Worker exprs = tuple(filter("self".__ne__, exprs)) 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker return exprs 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker 1250*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1251*da0073e9SAndroid Build Coastguard Worker# 1252*da0073e9SAndroid Build Coastguard Worker# Python / C++ Args 
Binding 1253*da0073e9SAndroid Build Coastguard Worker# 1254*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1255*da0073e9SAndroid Build Coastguard Worker 1256*da0073e9SAndroid Build Coastguard Worker 1257*da0073e9SAndroid Build Coastguard Worker# We explicitly enumerate the PythonArgParser unpacking methods for all 1258*da0073e9SAndroid Build Coastguard Worker# supported types. This might be more verbose than necessary, partially 1259*da0073e9SAndroid Build Coastguard Worker# because of the irregularity of unpacking method naming, partially 1260*da0073e9SAndroid Build Coastguard Worker# because we want to mimic the old codegen behavior - to reject 1261*da0073e9SAndroid Build Coastguard Worker# unexpected and/or unsupported cases which the old codegen rejects. 1262*da0073e9SAndroid Build Coastguard Worker# For certain cases it is intentionally more restrictive than necessary, 1263*da0073e9SAndroid Build Coastguard Worker# e.g.: it doesn't accepts doublelist with definite size. 
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    has_default_init = default_init is not None
    if has_default_init and str(t) not in (
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    ):
        raise RuntimeError(f"type '{t}' does not support unpacking with default")

    if isinstance(t, BaseType):
        if t.name in [
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ]:
            # These unpack methods line up with their schema names
            return t.name.name.lower()
        elif t.name == BaseTy.ScalarType:
            return "scalartypeWithDefault" if has_default_init else "scalartype"
        elif t.name == BaseTy.Device:
            return "deviceWithDefault" if has_default_init else "device"
        elif t.name == BaseTy.DeviceIndex:
            return "toInt64"
        elif t.name == BaseTy.int:
            return "toInt64"
        elif t.name == BaseTy.SymInt:
            return "toSymInt" if symint else "toInt64"
        elif t.name == BaseTy.bool:
            return "toBoolWithDefault" if has_default_init else "toBool"
        elif t.name == BaseTy.float:
            return "toDouble"
        elif t.name == BaseTy.str:
            return "stringView"
        elif t.name == BaseTy.Layout:
            return "layoutWithDefault" if has_default_init else "layout"
        elif t.name == BaseTy.MemoryFormat:
            return "memoryformat"

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            return "optionalTensor"
        elif str(t.elem) == "Generator":
            return "generator"
        elif str(t.elem) == "Dimname[]":
            return "toDimnameListOptional"
        elif not has_default_init and default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        ):
            # If default is None: append 'Optional' to elem's unpacking method
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        else:
            # Otherwise, load as underlying type with default
            return arg_parser_unpack_method(
                t.elem, default, default_init, symint=symint
            )

    elif isinstance(t, ListType):
        if str(t.elem) == "Tensor":
            # accept and use definite size
            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
        elif str(t.elem) == "Tensor?":
            return "list_of_optional_tensors"
        elif str(t.elem) == "Dimname":
            # accept definite size
            return "dimnamelist"
        elif str(t.elem) == "int":
            # accept definite size
            return "intlist"
        elif str(t.elem) == "float":
            return "doublelist"
        elif str(t.elem) == "SymInt":
            # accept definite size
            return "symintlist" if symint else "intlist"
        elif str(t.elem) == "Scalar":
            return "scalarlist"
    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")


# Return RHS expression for python argument using PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    has_default = a.default_init is not None
    unpack_method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )
    default = f", {a.default_init}" if has_default else ""
    expr = f"_r.{unpack_method}({arg_index}{default})"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=expr,
        index=arg_index,
        argument=a,
    )


# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    return {
        e.name: e
        for i, a in enumerate(ps.arguments())
        for e in (arg_parser_output_expr(i, a, symint=symint),)
    }


# argument name to type for scattered tensor options fields
TENSOR_OPTIONS_FIELDS = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}


# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    # Bind 'arg_parser_outputs' to 'lambda_args' by producing 'inits' and
    # 'lambda_args_exprs' for each lambda argument from the arg parser
    # outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why does this need to be a special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        inits.append(
            f"""\
const auto options = TensorOptions()
    .dtype({arg_parser_outputs['dtype'].expr})
    .device({arg_parser_outputs['device'].expr})
    .layout({arg_parser_outputs['layout'].expr})
    .requires_grad({arg_parser_outputs['requires_grad'].expr})
    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against output tensor
            if not f.func.is_out_fn():
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
                       {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
                       {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
"""
            )
        # we'll set requires_grad on outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args, but found {tensor_options_args_names}'
            )

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )
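

# Illustrative sketch (hypothetical; not part of torchgen and not called by the
# codegen). The helpers below mirror a *subset* of the rules implemented by
# arg_parser_unpack_method and arg_parser_output_expr. They operate on plain
# schema type strings such as "int?" or "Tensor[]" instead of
# torchgen.model.Type objects. They assume symint=True and do not handle
# definite-size lists such as "int[2]". The point is to show how a schema type
# plus an optional C++ default initializer becomes the
# '_r.<method>(<index>[, <default_init>])' expression that the generated
# binding uses to read a PythonArgParser output.

# Schema types that may be unpacked with an explicit default initializer.
_EXAMPLE_DEFAULT_INIT_TYPES = frozenset(
    [
        "ScalarType", "ScalarType?", "Device", "Device?",
        "Layout", "Layout?", "bool", "bool?",
    ]
)

# Base type -> (method without default_init, method with default_init or None).
_EXAMPLE_BASE_METHODS = {
    "Tensor": ("tensor", None),
    "Stream": ("stream", None),
    "Storage": ("storage", None),
    "Scalar": ("scalar", None),
    "Dimname": ("dimname", None),
    "DeviceIndex": ("toInt64", None),
    "int": ("toInt64", None),
    "SymInt": ("toSymInt", None),
    "float": ("toDouble", None),
    "str": ("stringView", None),
    "MemoryFormat": ("memoryformat", None),
    "bool": ("toBool", "toBoolWithDefault"),
    "ScalarType": ("scalartype", "scalartypeWithDefault"),
    "Device": ("device", "deviceWithDefault"),
    "Layout": ("layout", "layoutWithDefault"),
}

# Optional element types whose method does not follow the "+Optional" rule.
_EXAMPLE_OPTIONAL_SPECIAL = {
    "Tensor": "optionalTensor",
    "Generator": "generator",
    "Dimname[]": "toDimnameListOptional",
}

# List element type -> unpack method.
_EXAMPLE_LIST_METHODS = {
    "Tensor": "tensorlist",
    "Tensor?": "list_of_optional_tensors",
    "Dimname": "dimnamelist",
    "int": "intlist",
    "SymInt": "symintlist",
    "float": "doublelist",
    "Scalar": "scalarlist",
}


def _example_unpack_method(t, default=None, default_init=None):
    """Hypothetical: map a schema type string to an unpack method name."""
    if default_init is not None and t not in _EXAMPLE_DEFAULT_INIT_TYPES:
        raise RuntimeError(f"type '{t}' does not support unpacking with default")
    if t.endswith("?"):
        elem = t[:-1]
        if elem in _EXAMPLE_OPTIONAL_SPECIAL:
            return _EXAMPLE_OPTIONAL_SPECIAL[elem]
        if default_init is None and default in (
            None, "None", "::std::nullopt", "std::nullopt"
        ):
            # A None default appends 'Optional' to the element's method name.
            return _example_unpack_method(elem) + "Optional"
        # Otherwise unpack as the element type, honoring default_init.
        return _example_unpack_method(elem, default, default_init)
    if t.endswith("[]") and t[:-2] in _EXAMPLE_LIST_METHODS:
        return _EXAMPLE_LIST_METHODS[t[:-2]]
    if t in _EXAMPLE_BASE_METHODS:
        plain, with_default = _EXAMPLE_BASE_METHODS[t]
        return plain if default_init is None else with_default
    raise RuntimeError(f"type '{t}' is not supported by this sketch")


def _example_output_expr(arg_index, t, default=None, default_init=None):
    """Hypothetical: build the C++ accessor expression for one argument."""
    method = _example_unpack_method(t, default, default_init)
    suffix = f", {default_init}" if default_init is not None else ""
    return f"_r.{method}({arg_index}{suffix})"


# e.g. _example_output_expr(2, "bool") returns "_r.toBool(2)", matching the
# example in the comment above arg_parser_output_expr, and
# _example_output_expr(1, "int?") returns "_r.toInt64Optional(1)".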