from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
from torchgen.gen import pythonify_default
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    Type,
    Variant,
)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           Data Models
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# [Notes] python binding codegen
#
# The Python binding codegen produces code that takes the input list of
# PyObjects, finds the matching ATen C++ function using PythonArgParser,
# converts the PyObjects into C++ types and calls the ATen C++ function:
#
# +--------+  parsing   +------------------------+  binding   +-----------------------+
# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
# +--------+            +------------------------+            +-----------------------+
#
# The following examples demonstrate the data models the Python binding
# codegen needs to deal with and the tasks it needs to accomplish. They
# help explain the purpose of the data types introduced below.
#
# - Function Schema (source of truth)
#
#      aten::empty.names(int[] size, *, Dimname[]? names,
#                        ScalarType? dtype=None, Layout? layout=None,
#                        Device? device=None, bool? pin_memory=None,
#                        MemoryFormat? memory_format=None) -> Tensor
#
# - Python Signature
#
#   It's used to generate the input schema string for PythonArgParser.
#   Note: TensorOptions fields are reordered and the additional
#   'requires_grad' field is added:
#
#      empty(IntArrayRef size, *, DimnameList? names,
#            MemoryFormat? memory_format=None, ScalarType dtype=None,
#            Layout layout=torch.strided, Device device=None,
#            bool pin_memory=False, bool requires_grad=False)
#
# - C++ Signature
#
#   It's used to generate the C++ lambda formals & dispatch call.
#   Note: the scattered TensorOptions fields are packed into 'options'.
#
#      auto dispatch_empty =
#          [](IntArrayRef size, std::optional<DimnameList> names,
#             const TensorOptions & options,
#             std::optional<MemoryFormat> memory_format) -> Tensor {
#              pybind11::gil_scoped_release no_gil;
#              return torch::empty(size, names, options, memory_format);
#          };
#
# - Binding between Python Arguments and C++ Arguments
#
#   Given a set of Python Arguments in scope, we need to produce the
#   binding expressions that translate the Python API into the C++ API:
#
#            Python Args               Cpp Args              Binding Exprs
#     -----------------------------------------------------------------
#         0: size                      size                  '_r.intlist(0)'
#         1: names                     names                 'names' [special init]
#         2: memory_format -------+
#         3: dtype         -----+-|--> options               'options' [special packing]
#         4: layout            /  |
#         5: device           /   +--> memory_format         '_r.memoryformatOptional(2)'
#         6: pin_memory      /
#         7: requires_grad -+
#
#   So the full dispatch expression would look like:
#
#     dispatch_empty(_r.intlist(0), names, options,
#                    _r.memoryformatOptional(2))
#
#   Where does 'names' come from? It involves a special local init:
#
#     auto __names = _r.toDimnameListOptional(1);
#     std::optional<DimnameList> names =
#         __names ? std::make_optional(DimnameList(__names.value()))
#                 : std::nullopt;
#
#   Where does 'options' come from? It involves a special local init
#   for TensorOptions.
#   Note that the Python side has the additional 'requires_grad' field:
#
#     const auto options = TensorOptions()
#         .dtype(_r.scalartype(3))
#         .device(_r.device(5))
#         .layout(_r.layoutOptional(4))
#         .requires_grad(_r.toBool(7))
#         .pinned_memory(_r.toBool(6));
#
#   In some other cases one Python Argument can map to multiple C++
#   Arguments. For example:
#
#     aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
#       -> (Tensor values, Tensor indices)
#
#            Python Args               Cpp Args              Binding Exprs
#     ---------------------------------------------------------------------
#                                      +----> max            'out[0]'
#                                     /-----> max_values     'out[1]'
#         0: input                   /       self            '_r.tensor(0)'
#         1: dim                    /        dim             '_r.dimname(1)'
#         2: keepdim               /         keepdim         '_r.toBool(2)'
#         3: out -----+ [local init] out                     '_r.tensorlist_n<2>(3)'
#
#   As demonstrated above, the binding can involve reordering,
#   packing, unpacking and special local inits.
#
#
# Let's look at a concrete example:
#
#      static PythonArgParser parser({
#        "abs(Tensor input, *, Tensor out=None)",
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- Python Schema, represented by PythonSignature and PythonArgument
#
#      }, /*traceable=*/true);
#
#      ParsedArgs<2> parsed_args;
#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
#
#      ...
#
#      if (_r.isNone(1)) {
#          ~~~~~~~~~~~~  <--- Scattered PythonArgParser output (arg name = 'out')
#                             represented by PythonArgParserOutputExpr
#
#        // aten::abs(Tensor self) -> Tensor
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, base version
#
#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#                             ^
#                             +--- dispatch_lambda_args / dispatch_lambda_return_str
#                                  generated from NativeFunction / CppSignature
#                                  (deprecated PythonSignature is special)
#                                  arguments are represented by DispatchLambdaArgument
#
#          pybind11::gil_scoped_release no_gil;
#          return self.abs();
#                 ~~~~~~~~~~~  <--- cpp_dispatch_target / cpp_dispatch_exprs
#                                   generated from NativeFunction / CppSignature
#        };
#        return wrap(dispatch_abs(_r.tensor(0)));
#                                 ~~~~~~~~~~~~~
#                                  ^
#                                  +--- dispatch_lambda_exprs
#                                       binding PythonArgParserOutputExpr (python args)
#                                       and DispatchLambdaArgument (c++ args)
#
#      } else {
#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, out-variant
#
#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return at::abs_out(out, self);
#        };
#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
#      }
#
#
# [Notes] python interface codegen
# The python dataclasses below are used to generate both python binding code
# and pyi type hint signatures.
# In theory these two should look very similar, but there are a number of
# differences in how pyi signatures and python_arg_parser signatures are
# generated. These differences have been encapsulated in signature_str() vs.
# signature_str_pyi() to display the full signatures, and argument_str() vs.
# argument_str_pyi() to display arguments.
# For example, only pyi signatures include return types.


@dataclass(frozen=True)
class PythonReturns:
    returns: tuple[Return, ...]


@dataclass(frozen=True)
class PythonArgument:
    name: str
    type: Type
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            ^
    #                            +--- default_init str
    default_init: str | None

    # Compute argument formal for python argument parsing.
    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        type_str = (
            argument_type_str(self.type, symint=symint)
            .replace("const ", "")
            .replace(" &", "")
        )

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str in ["Tensor", "Number"] and not method:
            name = "input"

        # add default
        if self.default is not None:
            default = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
            }.get(self.default, self.default)
            return f"{type_str} {name}={default}"
        else:
            return f"{type_str} {name}"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        if name == "from":  # from is a Python keyword...
            name += "_"

        # pyi merges the _out and functional variants into the same signature, with an optional out arg
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # pyi deprecated signatures don't get defaults for their out arg
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )

        # add default
        if self.default is not None and not treat_as_no_default:
            if (
                isinstance(self.type, ListType)
                and self.type.elem == BaseType(BaseTy.int)
                and self.default.startswith("{")
                and self.default.endswith("}")
            ):
                default = (
                    "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
                )
            else:
                default = {
                    "nullptr": "None",
                    "::std::nullopt": "None",
                    "std::nullopt": "None",
                    "{}": "None",
                    "c10::MemoryFormat::Contiguous": "contiguous_format",
                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
                }.get(self.default, self.default)
            return f"{name}: {type_str} = {default}"
        else:
            return f"{name}: {type_str}"
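

# A minimal sketch (hypothetical helper, not used by the codegen) of the two
# renderings defined above, for an optional dtype argument:
def _example_argument_str() -> tuple[str, str]:
    a = PythonArgument(
        name="dtype",
        type=OptionalType(BaseType(BaseTy.ScalarType)),
        default="None",
        default_init=None,
    )
    # Parser string vs. pyi rendering (illustrative):
    #   ("ScalarType? dtype=None", "dtype: Optional[_dtype] = None")
    return a.argument_str(), a.argument_str_pyi()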


@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    # In the Python signature multiple output fields are packed into one 'out'
    # argument. When binding to C++, it's first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # then bound to scattered C++ output arguments as 'out[0]', 'out[1]', etc.
    # TODO: maybe we don't need to keep scattered out fields for the python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        if not outputs:
            return None

        size = len(outputs)
        if size == 1:
            return PythonOutArgument(
                name=outputs[0].name,
                type=outputs[0].type,
                default="None",
                default_init=None,
                outputs=outputs,
            )
        elif size > 1:
            if any(not a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            return PythonOutArgument(
                name="out",
                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                type=ListType(BaseType(BaseTy.Tensor), size),
                default="None",
                default_init=None,
                outputs=outputs,
            )
        raise AssertionError("Unexpected PythonOutArgument size")


@dataclass(frozen=True)
class PythonSignature:
    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are scattered kwargs arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take a TensorOptions object (e.g.
    # for the out variant), in which case they are used as scattered fields without
    # being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        return len(self.arguments())

    def output_idx(self) -> int:
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f'{self.name}({", ".join(schema_formals)})'

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # only pyi uses vararg signatures; vararg variants are not generated
        # for all signatures
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        num_args = self.arguments_count()
        num_positionalargs = len(self.input_args)

        have_vararg_version = False
        if num_args > 0:
            vararg_type = args[0].type
            if (
                isinstance(vararg_type, ListType)
                and str(vararg_type.elem) in ["int", "SymInt"]
                and num_positionalargs == 1
            ):
                have_vararg_version = True

        if not have_vararg_version:
            return None

        # The major changes in the vararg vs. the regular pyi signature: the
        # first argument becomes '*name', and the keyword-only asterisk is omitted.
        assert isinstance(vararg_type, ListType)
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
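

# A minimal sketch (hypothetical helper, not used by the codegen) that mirrors
# the 'abs' walkthrough in the notes above: a hand-built PythonSignature and
# the PythonArgParser string it renders to.
def _example_signature_str() -> str:
    sig = PythonSignature(
        name="abs",
        input_args=(
            PythonArgument(
                name="self",
                type=BaseType(BaseTy.Tensor),
                default=None,
                default_init=None,
            ),
        ),
        input_kwargs=(),
        output_args=PythonOutArgument.from_outputs(
            (
                PythonArgument(
                    name="out",
                    type=BaseType(BaseTy.Tensor),
                    default=None,
                    default_init=None,
                ),
            )
        ),
        returns=PythonReturns(returns=()),
        tensor_options_args=(),
        method=False,
    )
    # -> "abs(Tensor input, *, Tensor out=None)"
    return sig.signature_str()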


# The deprecated python signature involves some special logic, so create a
# dedicated data model to store these extra properties.
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    # Schema for the deprecated function
    deprecated_schema: FunctionSchema

    # The deprecated signature might miss some arguments that the corresponding
    # C++ signature expects. We need to store the constant default values to pass in.
    # For example:
    #   [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        return (
            PythonSignature.signature_str(
                self, skip_outputs=skip_outputs, symint=symint
            )
            + "|deprecated"
        )

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        returns_str = returns_str_pyi(self)
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # the codegen doesn't include vararg variants for deprecated signatures
        return None


# This struct is used to hold the PythonSignature and its corresponding
# NativeFunction BEFORE grouping base and out-variant functions.
# Why not store NativeFunction in PythonSignature or construct PythonSignature
# from NativeFunction? Because they are not 1-1 mapped.
# One native function could have both deprecated and non-deprecated python
# signatures - NativeFunction doesn't contain the information to construct the
# deprecated python signature.
# One python signature is used to handle both the base and the out-variant
# function - see 'PythonSignatureGroup'.
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    signature: PythonSignature
    function: NativeFunction


# We merge pairs of functions with signatures that are equivalent mod
# output arguments, and use a single entry in the python_arg_parser sig
# list for both (output arguments become optional).
@dataclass(frozen=True)
class PythonSignatureGroup:
    # The signature used for Python argument parsing. The outplace signature
    # is preferred if it exists, because it can be used to parse inputs for both
    # the out-place variant and the base version (with output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out)
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        if out is None:
            return PythonSignatureGroup(
                signature=functional.signature,
                base=functional.function,
                outplace=None,
            )

        # prefer the signature with optional out=... arguments because it's the
        # superset that can be used to parse input for both base and outplace.
        signature_kwargs = out.signature.__dict__.copy()

        # Out overloads in C++ don't have TensorOptions arguments,
        # so take these from the functional variant
        signature_kwargs[
            "tensor_options_args"
        ] = functional.signature.tensor_options_args

        return PythonSignatureGroup(
            signature=type(out.signature)(**signature_kwargs),
            base=functional.function,
            outplace=out.function,
        )


# C++ function dispatch is wrapped in a lambda function. The lambda function
# has almost the same signature as the C++ function, with only some small
# variations - see details below.
# This data model is used to represent arguments of the lambda function
# signature.
@dataclass(frozen=True)
class DispatchLambdaArgument:
    name: str
    type_str: str
    is_out_arg: bool


# To pass PyObject arguments to the C++ function (via the lambda wrapper),
# we first need to convert the PyObjects into simple C++ objects. This work
# is done by PythonArgParser.
# This data model is used to represent the output of PythonArgParser.
# It has a 1-1 mapping with PythonArgument in PythonSignature.
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    # argument name
    name: str

    # RHS expression to reference PythonArgParser output.
    expr: str

    # In some special cases we need to create a different expr, e.g.:
    # '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument it maps to.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        return f"_r.isNone({self.index})"


# To pass PythonArgParser output to the lambda wrapper, we need to bind
# PythonArgParserOutputExpr to DispatchLambdaArgument.
# They are not always 1-1 mapped, e.g. scattered TensorOptions fields
# need to be packed into a TensorOptions object, which is the argument
# that the lambda function wrapper takes.
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    # The exprs that provide the binding for lambda arguments, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # It has a 1-1 mapping with DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local inits, which might introduce new variables that
    # the 'exprs' above reference, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Helper Functions
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    return CppSignatureGroup.from_native_function(f, method=method).signature


def has_tensor_options(f: NativeFunction) -> bool:
    return f.func.arguments.tensor_options is not None


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Signature
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
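

# An illustrative sketch (hypothetical helper, not used by the codegen) of the
# schema type strings produced by argument_type_str(), defined just below:
def _example_argument_type_str() -> list[str]:
    return [
        argument_type_str(BaseType(BaseTy.Tensor)),  # "Tensor"
        argument_type_str(OptionalType(BaseType(BaseTy.ScalarType))),  # "ScalarType?"
        argument_type_str(ListType(BaseType(BaseTy.int), 2)),  # "IntArrayRef[2]"
        # with symint=False, SymInt lists degrade to plain IntArrayRef
        argument_type_str(ListType(BaseType(BaseTy.SymInt), None), symint=False),
    ]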


# 'simple_type' was introduced by the old codegen, and is slightly different
# from the python schema type: e.g. it doesn't have the '?' suffix for
# optional Tensor/TensorList, and doesn't have the '[size]' suffix for list types.
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return "Tensor"
        elif t.name == BaseTy.int:
            return "int64_t"
        elif t.name == BaseTy.float:
            return "double"
        elif t.name == BaseTy.str:
            return "c10::string_view"
        elif t.name in [
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.ConstQuantizerPtr,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            # Is it desired to keep '?' for simple_type with the new style dispatcher?
            return "Tensor?"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"{elem}?"
    elif isinstance(t, ListType):
        size = t.size if not simple_type else None
        if str(t.elem) == "bool":
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"
        elif str(t.elem) == "int":
            return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "SymInt":
            if symint:
                return (
                    f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
                )
            else:
                return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "Tensor":
            return f"TensorList[{size}]" if size is not None else "TensorList"
        elif str(t.elem) == "Scalar":
            return f"ScalarList[{size}]" if size is not None else "ScalarList"
        elif str(t.elem) == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            else:
                return "const c10::List<::std::optional<Tensor>> &"
        elif str(t.elem) == "Dimname":
            return f"DimnameList[{size}]" if size is not None else "DimnameList"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{elem}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")


def argument_type_size(t: Type) -> int | None:
    l = t.is_list_like()
    if l is not None and str(l.elem) != "bool":
        return l.size
    else:
        return None


def argument(a: Argument) -> PythonArgument:
    return PythonArgument(
        name=a.name,
        type=a.type,
        # TODO: directly translate a.default to python default
        default=(
            str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
            if a.default is not None
            else None
        ),
        default_init=None,
    )


# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    return signature_from_schema(
        f.func, category_override=f.category_override, method=method, pyi=pyi
    )


def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    args: list[Argument] = []
    args.extend(func.arguments.pre_self_positional)
    # Skip SelfArgument if this is a method.
    if not method and func.arguments.self_arg is not None:
        args.append(func.arguments.self_arg.argument)
    args.extend(func.arguments.post_self_positional)
    args.extend(func.arguments.pre_tensor_options_kwarg_only)
    # Skip TensorOptionsArguments. Python side TensorOptions
    # arguments are created based on different rules - see below.
    args.extend(func.arguments.post_tensor_options_kwarg_only)
    args.extend(func.arguments.out)

    input_arg_set = {a.name for a in func.arguments.flat_positional}
    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in func.arguments.out}

    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
    input_kwargs = tuple(
        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
    )
    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have a new property
    # (default_init) and a new argument 'requires_grad', which require some
    # special handling.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in func.arguments.flat_non_out
    )
    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
        raise ValueError(
            "argument named requires_grad is reserved, should not explicitly add it in the schema"
        )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            topt_args = func.arguments.tensor_options
            if topt_args is None:
                return None
            a = getattr(topt_args, name)
            if a.default is None or a.default == "None":
                return None
            return cpp.default_expr(a.default, a.type, symint=False)

        tensor_options_args.append(
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("dtype")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("layout")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=(
                    None
                    if is_like_or_new_function
                    else (
                        topt_default_init("device")
                        or "torch::tensors::get_default_device()"
                    )
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )

    returns = PythonReturns(returns=func.returns)

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=returns,
        method=method,
    )
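

# An end-to-end sketch (hypothetical helper, not used by the codegen; assumes
# FunctionSchema.parse accepts the native_functions.yaml 'func' syntax): a
# factory op's schema becomes a PythonSignature with the reconstructed
# tensor-options kwargs appended at the end.
def _example_signature_from_schema() -> str:
    schema = FunctionSchema.parse(
        "empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, "
        "Layout? layout=None, Device? device=None, bool? pin_memory=None, "
        "MemoryFormat? memory_format=None) -> Tensor"
    )
    sig = signature_from_schema(schema, category_override=None)
    # Roughly: "empty(IntArrayRef size, *, DimnameList? names,
    #                 MemoryFormat? memory_format=None, ScalarType? dtype=None,
    #                 Layout? layout=None, Device? device=None,
    #                 bool? pin_memory=False, bool? requires_grad=False)"
    return sig.signature_str()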


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Interface
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    if len(returns) <= 1 or all(r.name is None for r in returns):
        return []
    else:
        if any(r.name is None for r in returns):
            # When building on Windows, `PyStructSequence_UnnamedField` could not be
            # resolved by the linker for some reason, which caused a build error:
            #
            #   python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            #   PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in structseq; you must either name all fields,
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")

        return [str(r.name) for r in returns]


def argument_type_str_pyi(t: Type) -> str:
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = "Optional[" + ret + "]"

    return ret


def return_type_str_pyi(t: Type) -> str:
    # Where arguments are open to accepting Union, return types should return
    # concrete types

    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        elif t.name == BaseTy.Dimname:
            return "Optional[str]"
        else:
            return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)
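

# An illustrative sketch (hypothetical helper, not used by the codegen) of the
# asymmetry above: pyi argument types accept unions, while return types are
# concrete.
def _example_pyi_types() -> tuple[str, str]:
    int_list = ListType(BaseType(BaseTy.int), None)
    # -> ("_size", "Tuple[_int, ...]")
    return argument_type_str_pyi(int_list), return_type_str_pyi(int_list)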


def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        # These types are structseq objects which act like NamedTuples, but
        # their constructor acts like the constructor of tuple. Using
        # typing.NamedTuple does not allow us to override __init__.
        seq_type = f"Tuple[{', '.join(python_returns)}]"
        structseq_def_lines = [
            f"class {structseq_name}({seq_type}):",
        ]
        for name, typ in zip(field_names, python_returns):
            structseq_def_lines.extend(
                [
                    "    @property",
                    f"    def {name}(self) -> {typ}: ...",
                ]
            )
        structseq_def_lines.extend(
            [
                f"    def __new__(cls, sequence: {seq_type}): ...",
                f"    n_fields: _int = {len(field_names)}",
                f"    n_sequence_fields: _int = {len(field_names)}",
                "    n_unnamed_fields: _int = 0",
                "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
                "",  # add an extra newline
            ]
        )
        structseq_def = "\n".join(structseq_def_lines)
        # Example:
        # structseq_def = (
        #     "class max(Tuple[Tensor, Tensor]):\n"
        #     "    @property\n"
        #     "    def values(self) -> Tensor: ...\n"
        #     "    @property\n"
        #     "    def indices(self) -> Tensor: ...\n"
        #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
        #     "    n_fields: _int = 2\n"
        #     "    n_sequence_fields: _int = 2\n"
        #     "    n_unnamed_fields: _int = 0\n"
        #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing\n"
        # )
        return structseq_name, structseq_def
    return None


def returns_str_pyi(signature: PythonSignature) -> str:
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        return f"torch.return_types.{signature.name}"

    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    if len(python_returns) > 1:
        return "Tuple[" + ", ".join(python_returns) + "]"
    if len(python_returns) == 1:
        return python_returns[0]
    return "None"
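

# A minimal sketch (hypothetical helper, not used by the codegen) of the two
# rendering paths above: a named multi-return maps to a structseq reference.
def _example_returns_str_pyi() -> str:
    sig = PythonSignature(
        name="max",
        input_args=(),
        input_kwargs=(),
        output_args=None,
        returns=PythonReturns(
            returns=(
                Return("values", BaseType(BaseTy.Tensor), None),
                Return("indices", BaseType(BaseTy.Tensor), None),
            )
        ),
        tensor_options_args=(),
        method=False,
    )
    # -> "torch.return_types.max" (unnamed returns would give "Tuple[Tensor, Tensor]")
    return returns_str_pyi(sig)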


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ Function Dispatch
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# This section provides APIs to generate the code that does C++ function
# dispatch. The C++ function call is wrapped by a lambda function.
# For example:
#
#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
#      pybind11::gil_scoped_release no_gil;
#      return at::selu_(self);
#    };
#
# The lambda function's signature follows the C++ signature in common
# cases, e.g.:
#
#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For the out variant the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. That's because when calling the lambda we pass in the
# PythonArgParser output '_r.tensor(3)', which is a stack-allocated object
# and needs to be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
#
#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For the multi-output case it can keep using reference types because the
# PythonArgParser output has been unpacked to local variables, e.g.:
#
#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
#   //     Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
#
# For deprecated python signatures it should follow the deprecated python arg order.
# TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?


def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    if isinstance(ps, PythonSignatureDeprecated):
        schema = ps.deprecated_schema
    else:
        schema = f.func

    # Start with cpp arguments - the dispatch lambda signature always includes 'self'
    cpp_args = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_args: set[str] = {a.name for a in schema.arguments.out}

    # Convert from cpp argument to lambda argument
    def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
        type_str = cpp_arg.type
        is_out_arg = cpp_arg.name in out_args
        if ps.method and cpp_arg.name == "self":
            # For a method's 'self', we can use 'const Tensor &' and simply ignore mutability!
            type_str = "const at::Tensor &"
        else:
            # For other cases we need to prevent dangling refs to temps (unless it's
            # an unpacked scattered output).
            # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
            if ensure_temp_safe:
                type_str = {
                    "at::Tensor &": "at::Tensor",
                }.get(type_str, type_str)
        return DispatchLambdaArgument(
            name=cpp_arg.name,
            type_str=type_str,
            is_out_arg=is_out_arg,
        )

    return tuple(map(dispatch_lambda_arg, cpp_args))


# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    "at::Tensor",
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::vector<at::Tensor>",
    # Needed for flash attention forw/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
    "at::Scalar",
    "bool",
    "int64_t",
    "void*",
    "void",
    "at::QScheme",
    "double",
    "at::IntArrayRef",
    "at::ScalarType",
    "at::Stream",
}


def dispatch_lambda_return_str(f: NativeFunction) -> str:
    # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
    # because the dispatch lambdas take mutable arguments *by value*, not
    # by reference.
    # If you then return a reference to such an argument, you
    # will now have a pointer to a dangling stack entry. Not good.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                             ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                             ^^^^^^^
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to a temporary. Maybe we could assign it to a
    # variable itself.)
    returns_without_annotation = tuple(
        Return(r.name, r.type, None) for r in f.func.returns
    )
    return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
    if return_str not in SUPPORTED_RETURN_TYPES:
        raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
    return return_str


def cpp_dispatch_target(f: NativeFunction) -> str:
    symint = f.func.has_symint()
    name = cpp.name(f.func, symint_overload=symint)
    if Variant.method in f.variants:
        return f"self.{name}"
    if Variant.function in f.variants:
        if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
            namespace = "torch"
        else:
            namespace = "at"
        return f"{namespace}::{name}"
    raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")


def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()

    exprs: tuple[str, ...] = ()
    if not isinstance(python_signature, PythonSignatureDeprecated):
        # By default the exprs are consistent with the C++ signature.
        exprs = tuple(a.name for a in cpp_args)
    else:
        # For the deprecated python signature we may need to fill in some constants.
        exprs = tuple(
            filter(
                lambda n: n != "out" or f.func.is_out_fn(),
                python_signature.deprecated_args_exprs,
            )
        )

    if Variant.method in f.variants:
        exprs = tuple(filter("self".__ne__, exprs))

    return exprs


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       Python / C++ Args Binding
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# We explicitly enumerate the PythonArgParser unpacking methods for all
# supported types. This might be more verbose than necessary, partially
# because of the irregularity of unpacking method naming, partially
# because we want to mimic the old codegen behavior - to reject
# unexpected and/or unsupported cases which the old codegen rejects.
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accept doublelists with a definite size.
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    has_default_init = default_init is not None
    if has_default_init and str(t) not in (
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    ):
        raise RuntimeError(f"type '{t}' does not support unpacking with default")

    if isinstance(t, BaseType):
        if t.name in [
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ]:
            # These unpack methods line up with their schema names
            return t.name.name.lower()
        elif t.name == BaseTy.ScalarType:
            return "scalartypeWithDefault" if has_default_init else "scalartype"
        elif t.name == BaseTy.Device:
            return "deviceWithDefault" if has_default_init else "device"
        elif t.name == BaseTy.DeviceIndex:
            return "toInt64"
        elif t.name == BaseTy.int:
            return "toInt64"
        elif t.name == BaseTy.SymInt:
            return "toSymInt" if symint else "toInt64"
        elif t.name == BaseTy.bool:
            return "toBoolWithDefault" if has_default_init else "toBool"
        elif t.name == BaseTy.float:
            return "toDouble"
        elif t.name == BaseTy.str:
            return "stringView"
        elif t.name == BaseTy.Layout:
            return "layoutWithDefault" if has_default_init else "layout"
        elif t.name == BaseTy.MemoryFormat:
            return "memoryformat"

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            return "optionalTensor"
        elif str(t.elem) == "Generator":
            return "generator"
        elif str(t.elem) == "Dimname[]":
            return "toDimnameListOptional"
        elif not has_default_init and default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        ):
            # If default is None: append 'Optional' to elem's unpacking method
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        else:
            # Otherwise, load as the underlying type with default
            return arg_parser_unpack_method(
                t.elem, default, default_init, symint=symint
            )

    elif isinstance(t, ListType):
        if str(t.elem) == "Tensor":
            # accept and use definite size
            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
        elif str(t.elem) == "Tensor?":
            return "list_of_optional_tensors"
        elif str(t.elem) == "Dimname":
            # accept definite size
            return "dimnamelist"
        elif str(t.elem) == "int":
            # accept definite size
            return "intlist"
        elif str(t.elem) == "float":
            return "doublelist"
        elif str(t.elem) == "SymInt":
            # accept definite size
            return "symintlist" if symint else "intlist"
        elif str(t.elem) == "Scalar":
            return "scalarlist"
    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
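

# An illustrative sketch (hypothetical helper, not used by the codegen) of the
# unpack method selection above; the default_init string is made up:
def _example_unpack_methods() -> list[str]:
    return [
        # bool without a default_init -> "toBool"
        arg_parser_unpack_method(BaseType(BaseTy.bool), None, None),
        # ScalarType? with a default_init -> "scalartypeWithDefault"
        arg_parser_unpack_method(
            OptionalType(BaseType(BaseTy.ScalarType)), "None", "at::kFloat"
        ),
        # Tensor?[] -> "list_of_optional_tensors"
        arg_parser_unpack_method(
            ListType(OptionalType(BaseType(BaseTy.Tensor)), None), None, None
        ),
    ]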


# Return the RHS expression for a python argument using the PythonArgParser output,
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, it returns '_r.toBool(2)'.
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    has_default = a.default_init is not None
    unpack_method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )
    default = f", {a.default_init}" if has_default else ""
    expr = f"_r.{unpack_method}({arg_index}{default})"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=expr,
        index=arg_index,
        argument=a,
    )


# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    return {
        e.name: e
        for i, a in enumerate(ps.arguments())
        for e in (arg_parser_output_expr(i, a, symint=symint),)
    }


# argument name to type for scattered tensor options fields
TENSOR_OPTIONS_FIELDS = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}


# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    # This method binds 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
    # outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why does this need to be a special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to the python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        inits.append(
            f"""\
const auto options = TensorOptions()
    .dtype({arg_parser_outputs['dtype'].expr})
    .device({arg_parser_outputs['device'].expr})
    .layout({arg_parser_outputs['layout'].expr})
    .requires_grad({arg_parser_outputs['requires_grad'].expr})
    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against the output tensor
            if not f.func.is_out_fn():
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments()}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
                       {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
                       {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
"""
            )
        # we'll set requires_grad on the outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args, but found only [{tensor_options_args_names}]'
            )

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )
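

# An end-to-end sketch (hypothetical helper, not used by the codegen) of a
# single python-arg-to-C++ binding expression:
def _example_arg_parser_output_expr() -> str:
    a = PythonArgument(
        name="keepdim",
        type=BaseType(BaseTy.bool),
        default="False",
        default_init=None,
    )
    # -> "_r.toBool(2)"
    return arg_parser_output_expr(2, a).expr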