from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
from torchgen.gen import pythonify_default
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    Type,
    Variant,
)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           Data Models
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# [Notes] python binding codegen
#
# The Python binding codegen produces code that takes the input list of
# PyObjects, finds the matching ATen C++ function using PythonArgParser,
# converts the PyObjects into C++ types and calls the ATen C++ function:
#
# +--------+  parsing   +------------------------+  binding   +-----------------------+
# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
# +--------+            +------------------------+            +-----------------------+
#
# The following examples demonstrate the data models the Python binding
# codegen needs to deal with and the tasks it needs to accomplish. They
# help explain the purpose of the new data types introduced below.
#
#  - Function Schema (source of truth)
#
#      aten::empty.names(int[] size, *, Dimname[]? names,
#                        ScalarType? dtype=None, Layout? layout=None,
#                        Device? device=None, bool? pin_memory=None,
#                        MemoryFormat? memory_format=None) -> Tensor
#
#  - Python Signature
#
#    It's used to generate the input schema string for PythonArgParser.
#    Note: TensorOptions fields are reordered and the additional
#    'requires_grad' field is added:
#
#      empty(IntArrayRef size, *, DimnameList? names,
#            MemoryFormat? memory_format=None, ScalarType dtype=None,
#            Layout layout=torch.strided, Device device=None,
#            bool pin_memory=False, bool requires_grad=False)
#
#  - C++ Signature
#
#    It's used to generate the C++ lambda formals & dispatch call.
#    Note: the scattered TensorOptions fields are packed into 'options'.
#
#      auto dispatch_empty =
#          [](IntArrayRef size, std::optional<DimnameList> names,
#             const TensorOptions & options,
#             std::optional<MemoryFormat> memory_format) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return torch::empty(size, names, options, memory_format);
#      };
#
#  - Binding between Python Arguments and C++ Arguments
#
#    Given a set of Python Arguments in scope, we need to produce the
#    binding expressions that translate the Python API into the C++ API:
#
#            Python Args               Cpp Args       Binding Exprs
#     -----------------------------------------------------------------
#         0: size                      size           '_r.intlist(0)'
#         1: names                     names          'names' [special init]
#         2: memory_format -------+
#         3: dtype         -----+-|--> options        'options' [special packing]
#         4: layout            /  |
#         5: device           /   +--> memory_format  '_r.memoryformatOptional(2)'
#         6: pin_memory      /
#         7: requires_grad -+
#
#    So the full dispatch expression would look like:
#
#      dispatch_empty(_r.intlist(0), names, options,
#                     _r.memoryformatOptional(2))
#
#    Where does 'names' come from? It involves a special local init:
#
#      auto __names = _r.toDimnameListOptional(1);
#      std::optional<DimnameList> names =
#          __names ? std::make_optional(DimnameList(__names.value()))
#                  : std::nullopt;
#
#    Where does 'options' come from? It involves a special local init
#    for TensorOptions. Note that the Python side has the additional
#    'requires_grad' field:
#
#      const auto options = TensorOptions()
#          .dtype(_r.scalartype(3))
#          .device(_r.device(5))
#          .layout(_r.layoutOptional(4))
#          .requires_grad(_r.toBool(7))
#          .pinned_memory(_r.toBool(6));
#
#    In some other cases one Python Argument can map to multiple C++
#    Arguments. For example:
#
#     aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
#       -> (Tensor values, Tensor indices)
#
#            Python Args               Cpp Args          Binding Exprs
#     ---------------------------------------------------------------------
#                               +----> max               'out[0]'
#                              /-----> max_values        'out[1]'
#         0: input            /        self              '_r.tensor(0)'
#         1: dim             /         dim               '_r.dimname(1)'
#         2: keepdim        /          keepdim           '_r.toBool(2)'
#         3: out      -----+           [local init] out  '_r.tensorlist_n<2>(3)'
#
#    As demonstrated above, the binding can involve reordering,
#    packing, unpacking and special local inits.
#
#
#  Let's look at a concrete example:
#
#      static PythonArgParser parser({
#        "abs(Tensor input, *, Tensor out=None)",
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- Python Schema, represented by PythonSignature and PythonArgument
#
#      }, /*traceable=*/true);
#
#      ParsedArgs<2> parsed_args;
#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
#
#      ...
#
#      if (_r.isNone(1)) {
#          ~~~~~~~~~~~~  <--- Scattered PythonArgParser output (arg name = 'out')
#                             represented by PythonArgParserOutputExpr
#
#        // aten::abs(Tensor self) -> Tensor
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, base version
#
#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#                             ^
#                             +--- dispatch_lambda_args / dispatch_lambda_return_str
#                                  generated from NativeFunction / CppSignature
#                                  (deprecated PythonSignature is special)
#                                  arguments are represented by DispatchLambdaArgument
#
#          pybind11::gil_scoped_release no_gil;
#          return self.abs();
#                 ~~~~~~~~~~~  <--- cpp_dispatch_target / cpp_dispatch_exprs
#                                   generated from NativeFunction / CppSignature
#        };
#        return wrap(dispatch_abs(_r.tensor(0)));
#                                 ~~~~~~~~~~~~~
#                                  ^
#                                  +--- dispatch_lambda_exprs
#                                       binding PythonArgParserOutputExpr (python args)
#                                       and DispatchLambdaArgument (c++ args)
#
#      } else {
#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         ^
#         +--- NativeFunction schema, out-variant
#
#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return at::abs_out(out, self);
#        };
#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
#      }
#
#
# [Notes] python interface codegen
# The python dataclasses below are used to generate both python binding code
# and pyi type hint signatures.
# In theory these two should look very similar, but there are a number of
# differences in how pyi signatures vs. python_arg_parser signatures are
# generated. These differences have been encapsulated in signature_str() vs.
# signature_str_pyi() to display the full signatures, and argument_str() vs.
# argument_str_pyi() to display arguments.
# For example, only pyi signatures include return types.


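# As an illustrative sketch (not verbatim generated output): for a schema like
# 'aten::sin(Tensor self) -> Tensor', the two flavors would look roughly like
#
#   signature_str():      'sin(Tensor input)'
#   signature_str_pyi():  'def sin(input: Tensor) -> Tensor: ...'
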
@dataclass(frozen=True)
class PythonReturns:
    returns: tuple[Return, ...]


@dataclass(frozen=True)
class PythonArgument:
    name: str
    type: Type
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend()))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            ^
    #                            +--- default_init str
    default_init: str | None

    # Compute the argument formal for python argument parsing.
    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        type_str = (
            argument_type_str(self.type, symint=symint)
            .replace("const ", "")
            .replace(" &", "")
        )

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str in ["Tensor", "Number"] and not method:
            name = "input"

        # add default
        if self.default is not None:
            default = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
            }.get(self.default, self.default)
            return f"{type_str} {name}={default}"
        else:
            return f"{type_str} {name}"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        if name == "from":  # from is a Python keyword...
            name += "_"

        # pyi merges the _out and functional variants into the same signature, with an optional out arg
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # pyi deprecated signatures don't get defaults for their out arg
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )

        # add default
        if self.default is not None and not treat_as_no_default:
            if (
                isinstance(self.type, ListType)
                and self.type.elem == BaseType(BaseTy.int)
                and self.default.startswith("{")
                and self.default.endswith("}")
            ):
                default = (
                    "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
                )
            else:
                default = {
                    "nullptr": "None",
                    "::std::nullopt": "None",
                    "std::nullopt": "None",
                    "{}": "None",
                    "c10::MemoryFormat::Contiguous": "contiguous_format",
                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
                }.get(self.default, self.default)
            return f"{name}: {type_str} = {default}"
        else:
            return f"{name}: {type_str}"


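# E.g. (sketch) for a hypothetical argument 'Scalar? alpha=1', the two methods
# above would render:
#
#   argument_str():     'Scalar? alpha=1'
#   argument_str_pyi(): 'alpha: Optional[Union[Number, _complex]] = 1'
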
@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    # In the Python signature, multiple output fields are packed into one 'out'
    # argument. When binding to C++, it's first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # then bound to the scattered C++ output arguments as 'out[0]', 'out[1]', etc.
    # TODO: maybe we don't need to keep the scattered out fields for the python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        if not outputs:
            return None

        size = len(outputs)
        if size == 1:
            return PythonOutArgument(
                name=outputs[0].name,
                type=outputs[0].type,
                default="None",
                default_init=None,
                outputs=outputs,
            )
        elif size > 1:
            if any(not a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            return PythonOutArgument(
                name="out",
                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                type=ListType(BaseType(BaseTy.Tensor), size),
                default="None",
                default_init=None,
                outputs=outputs,
            )
        raise AssertionError(r"Unexpected PythonOutArgument size")


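# E.g. (sketch) for 'aten::max.names_dim_max' with outputs (max, max_values),
# from_outputs() packs them into a single PythonOutArgument(name='out',
# type=Tensor[2], default='None', ...) whose 'outputs' field keeps the
# original scattered output arguments.
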
@dataclass(frozen=True)
class PythonSignature:
    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are the scattered kwarg arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take a TensorOptions object
    # (e.g. for the out variant), in which case they will be used as scattered
    # fields without being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        return len(self.arguments())

    def output_idx(self) -> int:
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f'{self.name}({", ".join(schema_formals)})'

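    # E.g. (sketch) for the 'empty.names' schema in the notes at the top of
    # this file, signature_str() produces the PythonArgParser string shown
    # there: 'empty(IntArrayRef size, *, DimnameList? names, ...)'.
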
    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # only pyi uses vararg signatures
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
        num_args = self.arguments_count()
        num_positionalargs = len(self.input_args)

        have_vararg_version = False
        if num_args > 0:
            vararg_type = args[0].type
            if (
                isinstance(vararg_type, ListType)
                and str(vararg_type.elem) in ["int", "SymInt"]
                and num_positionalargs == 1
            ):
                have_vararg_version = True

        if not have_vararg_version:
            return None

        # Below are the major changes in vararg vs. regular pyi signatures:
        # vararg signatures also omit the asterisk.
        assert isinstance(vararg_type, ListType)
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'


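# E.g. (sketch) a signature whose only positional argument is an int[] (such
# as 'ones(IntArrayRef size, *, ...)') additionally gets a vararg variant
# roughly like:
#
#   def ones(*size: _int, ...) -> Tensor: ...
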
# The deprecated python signature involves some special logic, so we create a
# dedicated data model to store these extra properties.
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    # Schema for the deprecated function
    deprecated_schema: FunctionSchema

    # The deprecated signature might omit some arguments that the corresponding
    # C++ signature expects. We need to store the constant default values to pass in.
    # For example:
    #   [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        return (
            PythonSignature.signature_str(
                self, skip_outputs=skip_outputs, symint=symint
            )
            + "|deprecated"
        )

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        returns_str = returns_str_pyi(self)
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # the codegen doesn't include vararg variants for deprecated signatures
        return None


# This struct is used to hold the PythonSignature and its corresponding
# NativeFunction BEFORE grouping base and out-variant functions.
# Why not store NativeFunction in PythonSignature or construct PythonSignature
# from NativeFunction? Because they are not 1-1 mapped.
# One native function could have both deprecated and non-deprecated python
# signatures - NativeFunction doesn't contain information to construct the
# deprecated python signature.
# One python signature is used to handle both the base and the out-variant
# function - see 'PythonSignatureGroup'.
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    signature: PythonSignature
    function: NativeFunction


# We merge pairs of functions with signatures that are equivalent mod
# output arguments, and use a single entry in the python_arg_parser sig
# list for both (output arguments become optional).
@dataclass(frozen=True)
class PythonSignatureGroup:
    # The signature used for Python argument parsing. The outplace signature
    # is preferred if it exists, because it can be used to parse inputs for both
    # the out-place variant and the base version (with output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out)
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        if out is None:
            return PythonSignatureGroup(
                signature=functional.signature,
                base=functional.function,
                outplace=None,
            )

        # prefer the signature with optional out=... arguments because it's the
        # superset that can be used to parse input for both base and outplace.
        signature_kwargs = out.signature.__dict__.copy()

        # Out overloads in C++ don't have TensorOptions arguments,
        # so take these from the functional variant
        signature_kwargs[
            "tensor_options_args"
        ] = functional.signature.tensor_options_args

        return PythonSignatureGroup(
            signature=type(out.signature)(**signature_kwargs),
            base=functional.function,
            outplace=out.function,
        )


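# E.g. (sketch) pairing 'aten::abs' with 'aten::abs.out' yields a single
# group: the group's 'signature' comes from the out variant (whose optional
# 'out' argument can parse calls to either native function), with
# base = abs and outplace = abs_out.
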
# C++ function dispatch is wrapped in a lambda function. The lambda function
# has almost the same signature as the C++ function, only with some small
# variations - see details below.
# This data model is used to represent arguments of the lambda function
# signature.
@dataclass(frozen=True)
class DispatchLambdaArgument:
    name: str
    type_str: str
    is_out_arg: bool


# To pass PyObject arguments to the C++ function (via the lambda wrapper),
# we first need to convert the PyObjects into simple C++ objects. This work
# is done by PythonArgParser.
# This data model is used to represent the output of PythonArgParser.
# It has a 1-1 mapping with PythonArgument in PythonSignature.
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    # argument name
    name: str

    # RHS expression to reference PythonArgParser output.
    expr: str

    # In some special cases we need to create a different expr, e.g.:
    # '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument it maps to.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        return f"_r.isNone({self.index})"


# To pass PythonArgParser output to the lambda wrapper, we need to bind
# PythonArgParserOutputExpr to DispatchLambdaArgument.
# They are not always 1-1 mapped, e.g. the scattered TensorOptions fields
# need to be packed into a TensorOptions object, which is the argument
# that the lambda function wrapper takes.
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    # The exprs that provide the binding for lambda arguments, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # It has a 1-1 mapping with DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local inits, which might introduce new variables that
    # the 'exprs' above reference, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Helper Functions
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    return CppSignatureGroup.from_native_function(f, method=method).signature


def has_tensor_options(f: NativeFunction) -> bool:
    return f.func.arguments.tensor_options is not None


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Signature
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# 'simple_type' was introduced by the old codegen; it is slightly different
# from the python schema type, e.g. it has no '?' suffix for optional
# Tensor/TensorList and no '[size]' suffix for list types.
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return "Tensor"
        elif t.name == BaseTy.int:
            return "int64_t"
        elif t.name == BaseTy.float:
            return "double"
        elif t.name == BaseTy.str:
            return "c10::string_view"
        elif t.name in [
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.ConstQuantizerPtr,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            # Is it desired to keep '?' for simple_type with new style dispatcher?
            return "Tensor?"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"{elem}?"
    elif isinstance(t, ListType):
        size = t.size if not simple_type else None
        if str(t.elem) == "bool":
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"
        elif str(t.elem) == "int":
            return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "SymInt":
            if symint:
                return (
                    f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
                )
            else:
                return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "Tensor":
            return f"TensorList[{size}]" if size is not None else "TensorList"
        elif str(t.elem) == "Scalar":
            return f"ScalarList[{size}]" if size is not None else "ScalarList"
        elif str(t.elem) == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            else:
                return "const c10::List<::std::optional<Tensor>> &"
        elif str(t.elem) == "Dimname":
            return f"DimnameList[{size}]" if size is not None else "DimnameList"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{elem}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")


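# A few illustrative mappings (sketch, per the branches above):
#
#   Tensor    -> 'Tensor'          int      -> 'int64_t'
#   Tensor?   -> 'Tensor?'         int[2]   -> 'IntArrayRef[2]'
#   Dimname[] -> 'DimnameList'     SymInt[] -> 'SymIntArrayRef'
#                                              (or 'IntArrayRef' if not symint)
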
def argument_type_size(t: Type) -> int | None:
    l = t.is_list_like()
    if l is not None and str(l.elem) != "bool":
        return l.size
    else:
        return None


def argument(a: Argument) -> PythonArgument:
    return PythonArgument(
        name=a.name,
        type=a.type,
        # TODO: directly translate a.default to python default
        default=(
            str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
            if a.default is not None
            else None
        ),
        default_init=None,
    )


# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    return signature_from_schema(
        f.func, category_override=f.category_override, method=method, pyi=pyi
    )


def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    args: list[Argument] = []
    args.extend(func.arguments.pre_self_positional)
    # Skip SelfArgument if this is a method.
    if not method and func.arguments.self_arg is not None:
        args.append(func.arguments.self_arg.argument)
    args.extend(func.arguments.post_self_positional)
    args.extend(func.arguments.pre_tensor_options_kwarg_only)
    # Skip TensorOptionsArguments. Python-side TensorOptions
    # arguments are created based on different rules - see below.
    args.extend(func.arguments.post_tensor_options_kwarg_only)
    args.extend(func.arguments.out)

    input_arg_set = {a.name for a in func.arguments.flat_positional}
    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in func.arguments.out}

    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
    input_kwargs = tuple(
        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
    )
    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have a new property
    # (default_init) and a new argument 'requires_grad', which require some
    # special handling.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in func.arguments.flat_non_out
    )
    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
        raise ValueError(
            "argument named requires_grad is reserved, do not explicitly add it in the schema"
        )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            topt_args = func.arguments.tensor_options
            if topt_args is None:
                return None
            a = getattr(topt_args, name)
            if a.default is None or a.default == "None":
                return None
            return cpp.default_expr(a.default, a.type, symint=False)

        tensor_options_args.append(
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("dtype")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("layout")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=(
                    None
                    if is_like_or_new_function
                    else (
                        topt_default_init("device")
                        or "torch::tensors::get_default_device()"
                    )
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )

    returns = PythonReturns(returns=func.returns)

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=returns,
        method=method,
    )


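# E.g. (sketch) for the 'aten::empty.names' schema from the notes at the top
# of this file (a factory function), the resulting PythonSignature has
# input_args=(size,), input_kwargs=(names, memory_format), output_args=None,
# and the five synthesized tensor_options_args (dtype, layout, device,
# pin_memory, requires_grad).
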
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Interface
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    if len(returns) <= 1 or all(r.name is None for r in returns):
        return []
    else:
        if any(r.name is None for r in returns):
            # When building on Windows, `PyStructSequence_UnnamedField` could not be
            # resolved by the linker for some reason, which causes an error when
            # building:
            #
            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            # PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in structseq; you must either name all fields,
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")

        return [str(r.name) for r in returns]


def argument_type_str_pyi(t: Type) -> str:
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to   Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = "Optional[" + ret + "]"

    return ret


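# A few illustrative mappings (sketch, per the branches above):
#
#   int      -> '_int'              Scalar -> 'Union[Number, _complex]'
#   int?     -> 'Optional[_int]'    int[]  -> '_size'
#   Tensor[] -> 'Union[Tuple[Tensor, ...], List[Tensor]]'
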
def return_type_str_pyi(t: Type) -> str:
    # Where arguments are open to accepting Union, return types should return
    # concrete types

    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        elif t.name == BaseTy.Dimname:
            return "Optional[str]"
        else:
            return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)


def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        # These types are structseq objects which act like NamedTuples, but
        # whose constructor acts like the constructor of tuple. Using
        # typing.NamedTuple does not allow us to override __init__.
        seq_type = f"Tuple[{', '.join(python_returns)}]"
        structseq_def_lines = [
            f"class {structseq_name}({seq_type}):",
        ]
        for name, typ in zip(field_names, python_returns):
            structseq_def_lines.extend(
                [
                    "    @property",
                    f"    def {name}(self) -> {typ}: ...",
                ]
            )
        structseq_def_lines.extend(
            [
                f"    def __new__(cls, sequence: {seq_type}): ...",
                f"    n_fields: _int = {len(field_names)}",
                f"    n_sequence_fields: _int = {len(field_names)}",
                "    n_unnamed_fields: _int = 0",
                "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
                "",  # add an extra newline
            ]
        )
        structseq_def = "\n".join(structseq_def_lines)
        # Example:
        # structseq_def = (
        #     "class max(Tuple[Tensor, Tensor]):\n"
        #     "    @property\n"
        #     "    def values(self) -> Tensor: ...\n"
        #     "    @property\n"
        #     "    def indices(self) -> Tensor: ...\n"
        #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
        #     "    n_fields: _int = 2",
        #     "    n_sequence_fields: _int = 2",
        #     "    n_unnamed_fields: _int = 0",
        #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
        # )
        return structseq_name, structseq_def
    return None


def returns_str_pyi(signature: PythonSignature) -> str:
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        return f"torch.return_types.{signature.name}"

    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    if len(python_returns) > 1:
        return "Tuple[" + ", ".join(python_returns) + "]"
    if len(python_returns) == 1:
        return python_returns[0]
    return "None"


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ Function Dispatch
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# This section provides APIs to generate the code that does C++ function
# dispatch. The C++ function call is wrapped by a lambda function.
# For example:
#
#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
#      pybind11::gil_scoped_release no_gil;
#      return at::selu_(self);
#    };
#
# The lambda function's signature follows the C++ signature in common
# cases, e.g.:
#
#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For the out variant, the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because when calling the lambda we pass in the
# PythonArgParser output '_r.tensor(3)', which is a stack-allocated object
# and needs to be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
#
#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For the multi-output case it can keep using reference types, because the
# PythonArgParser output has been unpacked to local variables, e.g.:
#
#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
#   //     Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
#
# For deprecated python signatures, the lambda should follow the deprecated python arg order.
# TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?


def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    if isinstance(ps, PythonSignatureDeprecated):
        schema = ps.deprecated_schema
    else:
        schema = f.func

    # Start with cpp arguments - the dispatch lambda signature always includes 'self'
    cpp_args = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_args: set[str] = {a.name for a in schema.arguments.out}

    # Convert from cpp argument to lambda argument
    def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
        type_str = cpp_arg.type
        is_out_arg = cpp_arg.name in out_args
        if ps.method and cpp_arg.name == "self":
            # For a method's 'self', we can use 'const Tensor &' and simply ignore mutability!
            type_str = "const at::Tensor &"
        else:
            # For other cases we need to prevent dangling refs to temps (unless it's
            # an unpacked scattered output).
            # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
            if ensure_temp_safe:
                type_str = {
                    "at::Tensor &": "at::Tensor",
                }.get(type_str, type_str)
        return DispatchLambdaArgument(
            name=cpp_arg.name,
            type_str=type_str,
            is_out_arg=is_out_arg,
        )

    return tuple(map(dispatch_lambda_arg, cpp_args))


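# E.g. (sketch) for 'aten::add.out(Tensor self, Tensor other, *, Scalar
# alpha=1, Tensor(a!) out) -> Tensor(a!)' this produces, in C++ order:
#
#   DispatchLambdaArgument(name='out',   type_str='at::Tensor',          is_out_arg=True)
#   DispatchLambdaArgument(name='self',  type_str='const at::Tensor &',  is_out_arg=False)
#   DispatchLambdaArgument(name='other', type_str='const at::Tensor &',  is_out_arg=False)
#   DispatchLambdaArgument(name='alpha', type_str='const at::Scalar &',  is_out_arg=False)
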
# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    "at::Tensor",
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::vector<at::Tensor>",
    # Needed for flash attention forw/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
    "at::Scalar",
    "bool",
    "int64_t",
    "void*",
    "void",
    "at::QScheme",
    "double",
    "at::IntArrayRef",
    "at::ScalarType",
    "at::Stream",
}


def dispatch_lambda_return_str(f: NativeFunction) -> str:
    # [old codegen] Remove the type annotation (e.g. 'Tensor' rather than 'Tensor &')
    # because the dispatch lambdas take mutable arguments *by value*, not
    # by reference. If you then return a reference to such an argument, you
    # will now have a pointer to a dangling stack entry. Not good.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                            ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                            ^^^^^^^
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to a temporary.  Maybe we could assign it to a
    # variable itself.)
    returns_without_annotation = tuple(
        Return(r.name, r.type, None) for r in f.func.returns
    )
    return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
    if return_str not in SUPPORTED_RETURN_TYPES:
        raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
    return return_str


def cpp_dispatch_target(f: NativeFunction) -> str:
    symint = f.func.has_symint()
    name = cpp.name(f.func, symint_overload=symint)
    if Variant.method in f.variants:
        return f"self.{name}"
    if Variant.function in f.variants:
        if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
            namespace = "torch"
        else:
            namespace = "at"
        return f"{namespace}::{name}"
    raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")


def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()

    exprs: tuple[str, ...] = ()
    if not isinstance(python_signature, PythonSignatureDeprecated):
        # By default the exprs are consistent with the C++ signature.
        exprs = tuple(a.name for a in cpp_args)
    else:
        # For deprecated python signatures we may need to fill in some constants.
        exprs = tuple(
            filter(
                lambda n: n != "out" or f.func.is_out_fn(),
                python_signature.deprecated_args_exprs,
            )
        )

    if Variant.method in f.variants:
        exprs = tuple(filter("self".__ne__, exprs))

    return exprs


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                     Python / C++ Args Binding
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# We explicitly enumerate the PythonArgParser unpacking methods for all
# supported types. This might be more verbose than necessary, partially
# because of the irregularity of unpacking method naming, partially
# because we want to mimic the old codegen behavior - to reject
# unexpected and/or unsupported cases which the old codegen rejects.
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accept a doublelist with a definite size.
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    has_default_init = default_init is not None
    if has_default_init and str(t) not in (
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    ):
        raise RuntimeError(f"type '{t}' does not support unpacking with default")

    if isinstance(t, BaseType):
        if t.name in [
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ]:
            # These unpack methods line up with their schema names
            return t.name.name.lower()
        elif t.name == BaseTy.ScalarType:
            return "scalartypeWithDefault" if has_default_init else "scalartype"
        elif t.name == BaseTy.Device:
            return "deviceWithDefault" if has_default_init else "device"
        elif t.name == BaseTy.DeviceIndex:
            return "toInt64"
        elif t.name == BaseTy.int:
            return "toInt64"
        elif t.name == BaseTy.SymInt:
            return "toSymInt" if symint else "toInt64"
        elif t.name == BaseTy.bool:
            return "toBoolWithDefault" if has_default_init else "toBool"
        elif t.name == BaseTy.float:
            return "toDouble"
        elif t.name == BaseTy.str:
            return "stringView"
        elif t.name == BaseTy.Layout:
            return "layoutWithDefault" if has_default_init else "layout"
        elif t.name == BaseTy.MemoryFormat:
            return "memoryformat"

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            return "optionalTensor"
        elif str(t.elem) == "Generator":
            return "generator"
        elif str(t.elem) == "Dimname[]":
            return "toDimnameListOptional"
        elif not has_default_init and default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        ):
            # If default is None: append 'Optional' to elem's unpacking method
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        else:
            # Otherwise, load as underlying type with default
            return arg_parser_unpack_method(
                t.elem, default, default_init, symint=symint
            )

    elif isinstance(t, ListType):
        if str(t.elem) == "Tensor":
            # accept and use definite size
            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
        elif str(t.elem) == "Tensor?":
            return "list_of_optional_tensors"
        elif str(t.elem) == "Dimname":
            # accept definite size
            return "dimnamelist"
        elif str(t.elem) == "int":
            # accept definite size
            return "intlist"
        elif str(t.elem) == "float":
            return "doublelist"
        elif str(t.elem) == "SymInt":
            # accept definite size
            return "symintlist" if symint else "intlist"
        elif str(t.elem) == "Scalar":
            return "scalarlist"
    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")


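# A few illustrative mappings (sketch, per the branches above):
#
#   Tensor   -> 'tensor'            bool  -> 'toBool'
#   Tensor?  -> 'optionalTensor'    int[] -> 'intlist'
#   SymInt[] -> 'symintlist' (or 'intlist' if not symint)
#   ScalarType? with a default_init -> 'scalartypeWithDefault'
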
# Return RHS expression for python argument using PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    has_default = a.default_init is not None
    unpack_method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )
    default = f", {a.default_init}" if has_default else ""
    expr = f"_r.{unpack_method}({arg_index}{default})"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=expr,
        index=arg_index,
        argument=a,
    )


# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    return {
        e.name: e
        for i, a in enumerate(ps.arguments())
        for e in (arg_parser_output_expr(i, a, symint=symint),)
    }


# argument name to type for scattered tensor options fields
TENSOR_OPTIONS_FIELDS = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}


# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    # This method binds 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using the arg
    # parser outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why does this need to be a special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # a method's self is passed directly to the python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        inits.append(
            f"""\
const auto options = TensorOptions()
    .dtype({arg_parser_outputs['dtype'].expr})
    .device({arg_parser_outputs['device'].expr})
    .layout({arg_parser_outputs['layout'].expr})
    .requires_grad({arg_parser_outputs['requires_grad'].expr})
    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against the output tensor
            if not f.func.is_out_fn():
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments()}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
                       {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
                       {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
"""
            )
        # we'll set requires_grad on the outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args, but found only [{tensor_options_args_names}]'
            )

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )
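

# E.g. (sketch) for the 'aten::empty.names' walkthrough in the notes at the
# top of this file, this yields roughly:
#
#   exprs = ('_r.intlist(0)', 'names', 'options', '_r.memoryformatOptional(2)')
#   inits = [the '__names' unwrap/rewrap init, the 'options' packing init]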