xref: /aosp_15_r20/external/pytorch/torchgen/api/python.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
7*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import Binding, CppSignature, CppSignatureGroup
8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import pythonify_default
9*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
10*da0073e9SAndroid Build Coastguard Worker    Argument,
11*da0073e9SAndroid Build Coastguard Worker    BaseTy,
12*da0073e9SAndroid Build Coastguard Worker    BaseType,
13*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
14*da0073e9SAndroid Build Coastguard Worker    ListType,
15*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
16*da0073e9SAndroid Build Coastguard Worker    OptionalType,
17*da0073e9SAndroid Build Coastguard Worker    Return,
18*da0073e9SAndroid Build Coastguard Worker    Type,
19*da0073e9SAndroid Build Coastguard Worker    Variant,
20*da0073e9SAndroid Build Coastguard Worker)
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
24*da0073e9SAndroid Build Coastguard Worker#
25*da0073e9SAndroid Build Coastguard Worker#                           Data Models
26*da0073e9SAndroid Build Coastguard Worker#
27*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
28*da0073e9SAndroid Build Coastguard Worker#
29*da0073e9SAndroid Build Coastguard Worker# [Notes] python binding codegen
30*da0073e9SAndroid Build Coastguard Worker#
31*da0073e9SAndroid Build Coastguard Worker# The Python binding codegen produces code that takes the input list of
32*da0073e9SAndroid Build Coastguard Worker# PyObjects, finds the matching ATen C++ function using PythonArgParser,
33*da0073e9SAndroid Build Coastguard Worker# converts the PyObjects into C++ types and calls the ATen C++ function:
34*da0073e9SAndroid Build Coastguard Worker#
35*da0073e9SAndroid Build Coastguard Worker# +--------+  parsing   +------------------------+  binding   +-----------------------+
36*da0073e9SAndroid Build Coastguard Worker# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
37*da0073e9SAndroid Build Coastguard Worker# +--------+            +------------------------+            +-----------------------+
38*da0073e9SAndroid Build Coastguard Worker#
39*da0073e9SAndroid Build Coastguard Worker# The following examples demonstrate the data models the Python binding
40*da0073e9SAndroid Build Coastguard Worker# codegen needs to deal with and the tasks it needs to accomplish. It
41*da0073e9SAndroid Build Coastguard Worker# helps explain the purpose of the new data types introduced below.
42*da0073e9SAndroid Build Coastguard Worker#
43*da0073e9SAndroid Build Coastguard Worker#  - Function Schema (source of truth)
44*da0073e9SAndroid Build Coastguard Worker#
45*da0073e9SAndroid Build Coastguard Worker#      aten::empty.names(int[] size, *, Dimname[]? names,
46*da0073e9SAndroid Build Coastguard Worker#                        ScalarType? dtype=None, Layout? layout=None,
47*da0073e9SAndroid Build Coastguard Worker#                        Device? device=None, bool? pin_memory=None,
48*da0073e9SAndroid Build Coastguard Worker#                        MemoryFormat? memory_format=None) -> Tensor
49*da0073e9SAndroid Build Coastguard Worker#
50*da0073e9SAndroid Build Coastguard Worker#  - Python Signature
51*da0073e9SAndroid Build Coastguard Worker#
52*da0073e9SAndroid Build Coastguard Worker#    It's used to generate input schema string for PythonArgParser.
53*da0073e9SAndroid Build Coastguard Worker#    Note: TensorOptions fields are reordered and the additional
54*da0073e9SAndroid Build Coastguard Worker#    'requires_grad' field is added:
55*da0073e9SAndroid Build Coastguard Worker#
56*da0073e9SAndroid Build Coastguard Worker#      empty(IntArrayRef size, *, DimnameList? names,
57*da0073e9SAndroid Build Coastguard Worker#            MemoryFormat? memory_format=None, ScalarType dtype=None,
58*da0073e9SAndroid Build Coastguard Worker#            Layout layout=torch.strided, Device device=None,
59*da0073e9SAndroid Build Coastguard Worker#            bool pin_memory=False, bool requires_grad=False)
60*da0073e9SAndroid Build Coastguard Worker#
61*da0073e9SAndroid Build Coastguard Worker#  - C++ Signature
62*da0073e9SAndroid Build Coastguard Worker#
63*da0073e9SAndroid Build Coastguard Worker#    It's used to generate C++ lambda formals & dispatch call.
64*da0073e9SAndroid Build Coastguard Worker#    Note: the scattered TensorOptions fields are packed into 'options'.
65*da0073e9SAndroid Build Coastguard Worker#
66*da0073e9SAndroid Build Coastguard Worker#      auto dispatch_empty =
67*da0073e9SAndroid Build Coastguard Worker#          [](IntArrayRef size, std::optional<DimnameList> names,
68*da0073e9SAndroid Build Coastguard Worker#             const TensorOptions & options,
69*da0073e9SAndroid Build Coastguard Worker#             std::optional<MemoryFormat> memory_format) -> Tensor {
70*da0073e9SAndroid Build Coastguard Worker#          pybind11::gil_scoped_release no_gil;
71*da0073e9SAndroid Build Coastguard Worker#          return torch::empty(size, names, options, memory_format);
72*da0073e9SAndroid Build Coastguard Worker#      };
73*da0073e9SAndroid Build Coastguard Worker#
74*da0073e9SAndroid Build Coastguard Worker#  - Binding between Python Arguments and C++ Arguments
75*da0073e9SAndroid Build Coastguard Worker#
76*da0073e9SAndroid Build Coastguard Worker#    Given a set of Python Arguments in scope, we need to produce the
77*da0073e9SAndroid Build Coastguard Worker#    binding expressions that translate the Python API into the C++ API:
78*da0073e9SAndroid Build Coastguard Worker#
79*da0073e9SAndroid Build Coastguard Worker#            Python Args               Cpp Args       Binding Exprs
80*da0073e9SAndroid Build Coastguard Worker#     -----------------------------------------------------------------
81*da0073e9SAndroid Build Coastguard Worker#         0: size                      size           '_r.intlist(0)'
82*da0073e9SAndroid Build Coastguard Worker#         1: names                     names          'names' [special init]
83*da0073e9SAndroid Build Coastguard Worker#         2: memory_format -------+
84*da0073e9SAndroid Build Coastguard Worker#         3: dtype         -----+-|--> options        'options' [special packing]
85*da0073e9SAndroid Build Coastguard Worker#         4: layout            /  |
86*da0073e9SAndroid Build Coastguard Worker#         5: device           /   +--> memory_format  '_r.memoryformatOptional(2)'
87*da0073e9SAndroid Build Coastguard Worker#         6: pin_memory      /
88*da0073e9SAndroid Build Coastguard Worker#         7: requires_grad -+
89*da0073e9SAndroid Build Coastguard Worker#
90*da0073e9SAndroid Build Coastguard Worker#    So the full dispatch expression would look like:
91*da0073e9SAndroid Build Coastguard Worker#
92*da0073e9SAndroid Build Coastguard Worker#      dispatch_empty(_r.intlist(0), names, options,
93*da0073e9SAndroid Build Coastguard Worker#                     _r.memoryformatOptional(2))
94*da0073e9SAndroid Build Coastguard Worker#
95*da0073e9SAndroid Build Coastguard Worker#    Where does 'names' come from? It involves special local init:
96*da0073e9SAndroid Build Coastguard Worker#
97*da0073e9SAndroid Build Coastguard Worker#      auto __names = _r.toDimnameListOptional(1);
98*da0073e9SAndroid Build Coastguard Worker#      std::optional<DimnameList> names =
99*da0073e9SAndroid Build Coastguard Worker#          __names ? std::make_optional(DimnameList(__names.value()))
100*da0073e9SAndroid Build Coastguard Worker#                  : std::nullopt;
101*da0073e9SAndroid Build Coastguard Worker#
102*da0073e9SAndroid Build Coastguard Worker#    Where does 'options' come from? It involves special local init
103*da0073e9SAndroid Build Coastguard Worker#    for TensorOptions. Note that Python side has the additional
104*da0073e9SAndroid Build Coastguard Worker#    'requires_grad' field:
105*da0073e9SAndroid Build Coastguard Worker#
106*da0073e9SAndroid Build Coastguard Worker#      const auto options = TensorOptions()
107*da0073e9SAndroid Build Coastguard Worker#          .dtype(_r.scalartype(3))
108*da0073e9SAndroid Build Coastguard Worker#          .device(_r.device(5))
109*da0073e9SAndroid Build Coastguard Worker#          .layout(_r.layoutOptional(4))
110*da0073e9SAndroid Build Coastguard Worker#          .requires_grad(_r.toBool(7))
111*da0073e9SAndroid Build Coastguard Worker#          .pinned_memory(_r.toBool(6));
112*da0073e9SAndroid Build Coastguard Worker#
113*da0073e9SAndroid Build Coastguard Worker#    In some other cases one Python Argument can map to multiple C++
114*da0073e9SAndroid Build Coastguard Worker#    Arguments. For example:
115*da0073e9SAndroid Build Coastguard Worker#
116*da0073e9SAndroid Build Coastguard Worker#     aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
117*da0073e9SAndroid Build Coastguard Worker#       -> (Tensor values, Tensor indices)
118*da0073e9SAndroid Build Coastguard Worker#
119*da0073e9SAndroid Build Coastguard Worker#            Python Args               Cpp Args          Binding Exprs
120*da0073e9SAndroid Build Coastguard Worker#     ---------------------------------------------------------------------
121*da0073e9SAndroid Build Coastguard Worker#                               +----> max               'out[0]'
122*da0073e9SAndroid Build Coastguard Worker#                              /-----> max_values        'out[1]'
123*da0073e9SAndroid Build Coastguard Worker#         0: input            /        self              '_r.tensor(0)'
124*da0073e9SAndroid Build Coastguard Worker#         1: dim             /         dim               '_r.dimname(1)'
125*da0073e9SAndroid Build Coastguard Worker#         2: keepdim        /          keepdim           '_r.toBool(2)'
126*da0073e9SAndroid Build Coastguard Worker#         3: out      -----+           [local init] out  '_r.tensorlist_n<2>(3)'
127*da0073e9SAndroid Build Coastguard Worker#
128*da0073e9SAndroid Build Coastguard Worker#    As demonstrated above, the binding can involve reordering,
129*da0073e9SAndroid Build Coastguard Worker#    packing, unpacking and special local inits.
130*da0073e9SAndroid Build Coastguard Worker#
131*da0073e9SAndroid Build Coastguard Worker#
132*da0073e9SAndroid Build Coastguard Worker#  Let's look at a concrete example:
133*da0073e9SAndroid Build Coastguard Worker#
134*da0073e9SAndroid Build Coastguard Worker#      static PythonArgParser parser({
135*da0073e9SAndroid Build Coastguard Worker#        "abs(Tensor input, *, Tensor out=None)",
136*da0073e9SAndroid Build Coastguard Worker#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
137*da0073e9SAndroid Build Coastguard Worker#         ^
138*da0073e9SAndroid Build Coastguard Worker#         +--- Python Schema, represented by PythonSignature and PythonArgument
139*da0073e9SAndroid Build Coastguard Worker#
140*da0073e9SAndroid Build Coastguard Worker#      }, /*traceable=*/true);
141*da0073e9SAndroid Build Coastguard Worker#
142*da0073e9SAndroid Build Coastguard Worker#      ParsedArgs<2> parsed_args;
143*da0073e9SAndroid Build Coastguard Worker#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
144*da0073e9SAndroid Build Coastguard Worker#
145*da0073e9SAndroid Build Coastguard Worker#      ...
146*da0073e9SAndroid Build Coastguard Worker#
147*da0073e9SAndroid Build Coastguard Worker#      if (_r.isNone(1)) {
148*da0073e9SAndroid Build Coastguard Worker#          ~~~~~~~~~~~~  <--- Scattered PythonArgParser output (arg name = 'out')
149*da0073e9SAndroid Build Coastguard Worker#                             represented by PythonArgParserOutputExpr
150*da0073e9SAndroid Build Coastguard Worker#
151*da0073e9SAndroid Build Coastguard Worker#        // aten::abs(Tensor self) -> Tensor
152*da0073e9SAndroid Build Coastguard Worker#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
153*da0073e9SAndroid Build Coastguard Worker#         ^
154*da0073e9SAndroid Build Coastguard Worker#         +--- NativeFunction schema, base version
155*da0073e9SAndroid Build Coastguard Worker#
156*da0073e9SAndroid Build Coastguard Worker#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
157*da0073e9SAndroid Build Coastguard Worker#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
158*da0073e9SAndroid Build Coastguard Worker#                             ^
159*da0073e9SAndroid Build Coastguard Worker#                             +--- dispatch_lambda_args / dispatch_lambda_return_str
160*da0073e9SAndroid Build Coastguard Worker#                                  generated from NativeFunction / CppSignature
161*da0073e9SAndroid Build Coastguard Worker#                                  (deprecated PythonSignature is special)
162*da0073e9SAndroid Build Coastguard Worker#                                  arguments are represented by DispatchLambdaArgument
163*da0073e9SAndroid Build Coastguard Worker#
164*da0073e9SAndroid Build Coastguard Worker#          pybind11::gil_scoped_release no_gil;
165*da0073e9SAndroid Build Coastguard Worker#          return self.abs();
166*da0073e9SAndroid Build Coastguard Worker#                 ~~~~~~~~~~~  <--- cpp_dispatch_target / cpp_dispatch_exprs
167*da0073e9SAndroid Build Coastguard Worker#                                   generated from NativeFunction / CppSignature
168*da0073e9SAndroid Build Coastguard Worker#        };
169*da0073e9SAndroid Build Coastguard Worker#        return wrap(dispatch_abs(_r.tensor(0)));
170*da0073e9SAndroid Build Coastguard Worker#                                 ~~~~~~~~~~~~~
171*da0073e9SAndroid Build Coastguard Worker#                                  ^
172*da0073e9SAndroid Build Coastguard Worker#                                  +--- dispatch_lambda_exprs
173*da0073e9SAndroid Build Coastguard Worker#                                       binding PythonArgParserOutputExpr (python args)
174*da0073e9SAndroid Build Coastguard Worker#                                       and DispatchLambdaArgument (c++ args)
175*da0073e9SAndroid Build Coastguard Worker#
176*da0073e9SAndroid Build Coastguard Worker#      } else {
177*da0073e9SAndroid Build Coastguard Worker#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
178*da0073e9SAndroid Build Coastguard Worker#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
179*da0073e9SAndroid Build Coastguard Worker#         ^
180*da0073e9SAndroid Build Coastguard Worker#         +--- NativeFunction schema, out-variant
181*da0073e9SAndroid Build Coastguard Worker#
182*da0073e9SAndroid Build Coastguard Worker#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
183*da0073e9SAndroid Build Coastguard Worker#          pybind11::gil_scoped_release no_gil;
184*da0073e9SAndroid Build Coastguard Worker#          return at::abs_out(out, self);
185*da0073e9SAndroid Build Coastguard Worker#        };
186*da0073e9SAndroid Build Coastguard Worker#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
187*da0073e9SAndroid Build Coastguard Worker#      }
188*da0073e9SAndroid Build Coastguard Worker#
189*da0073e9SAndroid Build Coastguard Worker#
190*da0073e9SAndroid Build Coastguard Worker# [Notes] python interface codegen
191*da0073e9SAndroid Build Coastguard Worker# The python dataclasses below are used to generate both python binding code
192*da0073e9SAndroid Build Coastguard Worker# and pyi type hint signatures.
193*da0073e9SAndroid Build Coastguard Worker# In theory these two should look very similar, but there are a number of differences
194*da0073e9SAndroid Build Coastguard Worker# in how pyi signatures vs. python_arg_parser signatures are generated.
195*da0073e9SAndroid Build Coastguard Worker# These differences have been encapsulated in signature_str() vs. signature_str_pyi()
196*da0073e9SAndroid Build Coastguard Worker# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
197*da0073e9SAndroid Build Coastguard Worker# For example, only pyi signatures include return types.
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class PythonReturns:
    """Immutable holder for the return entries of a schema.

    Return types are only rendered into ``.pyi`` signatures; python_arg_parser
    signature strings do not include them.
    """

    # The schema's Return entries.
    returns: tuple[Return, ...]
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class PythonArgument:
    """A single argument of a Python-facing signature.

    It can be rendered either as a python_arg_parser formal (``argument_str``)
    or as a ``.pyi`` type-hint parameter (``argument_str_pyi``).
    """

    name: str
    type: Type
    # Default value, or None when the argument has no default. The methods
    # below translate C++ spellings such as "nullptr", "std::nullopt" and "{}"
    # into their Python equivalents.
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            ^
    #                            +--- default_init str
    default_init: str | None

    # Compute argument formal for python argument parsing.
    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        """Return this argument as a python_arg_parser formal, e.g. ``"Tensor input"``.

        Args:
            method: True when rendering a method binding; ``self`` is then kept
                instead of being renamed to ``input``.
            symint: forwarded to ``argument_type_str``.
        """
        type_str = (
            argument_type_str(self.type, symint=symint)
            .replace("const ", "")
            .replace(" &", "")
        )

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str in ["Tensor", "Number"] and not method:
            name = "input"

        # add default
        if self.default is not None:
            default = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
            }.get(self.default, self.default)
            return f"{type_str} {name}={default}"
        else:
            return f"{type_str} {name}"

    @staticmethod
    def _int_list_default_pyi(default: str) -> str:
        """Convert a braced C++ int-list default into a Python tuple literal.

        ``"{-2,-1}"`` -> ``"(-2, -1)"`` and ``"{}"`` -> ``"()"``. A single element
        gets a trailing comma (``"{0}"`` -> ``"(0,)"``); without it the result
        would be a parenthesized scalar rather than a tuple. Empty items (e.g.
        from a trailing comma) are dropped.
        """
        items = [item.strip() for item in default[1:-1].split(",")]
        items = [item for item in items if item]
        if len(items) == 1:
            return f"({items[0]},)"
        return "(" + ", ".join(items) + ")"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        """Return this argument as a ``.pyi`` parameter, e.g. ``"input: Tensor"``.

        Args:
            method: True when rendering a method stub; ``self`` is then kept
                instead of being renamed to ``input``.
            deprecated: True for deprecated signatures. This disables the
                ``self`` -> ``input`` rename and the ``Optional[...]`` wrapping
                of a Tensor ``out``, and drops the default of a
                ``PythonOutArgument`` whose default is "None".
        """
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        if name == "from":  # from is a Python keyword...
            name += "_"

        # pyi merges the _out and functional variants into the same signature, with an optional out arg
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # pyi deprecated signatures don't get defaults for their out arg
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )

        # add default
        if self.default is not None and not treat_as_no_default:
            if (
                isinstance(self.type, ListType)
                and self.type.elem == BaseType(BaseTy.int)
                and self.default.startswith("{")
                and self.default.endswith("}")
            ):
                # Braced C++ int-list initializer -> Python tuple literal.
                default = self._int_list_default_pyi(self.default)
            else:
                default = {
                    "nullptr": "None",
                    "::std::nullopt": "None",
                    "std::nullopt": "None",
                    "{}": "None",
                    "c10::MemoryFormat::Contiguous": "contiguous_format",
                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
                }.get(self.default, self.default)
            return f"{name}: {type_str} = {default}"
        else:
            return f"{name}: {type_str}"
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    """The ``out=`` argument of a Python signature.

    In the Python signature, multiple output fields are packed into a single
    'out' argument. When binding to C++, it is first bound to a local 'out'
    variable:
        'auto out = _r.tensorlist_n<2>(2);',
    and then bound to the scattered C++ output arguments as 'out[0]',
    'out[1]', etc.
    """

    # The individual (scattered) output arguments that this argument packs.
    # TODO: maybe we don't need to keep the scattered out fields for the python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        """Build the out argument that represents ``outputs``.

        Returns None when ``outputs`` is empty. A single output keeps its own
        name and type. Several outputs are packed into one argument named
        "out" of type ``ListType(BaseType(BaseTy.Tensor), N)``. In both cases
        the default is "None".

        Raises:
            RuntimeError: if several outputs are given and any of them has a
                type for which ``is_tensor_like()`` is false.
        """
        if not outputs:
            return None

        count = len(outputs)
        if count == 1:
            out_name, out_type = outputs[0].name, outputs[0].type
        elif count > 1:
            if not all(arg.type.is_tensor_like() for arg in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
            out_name, out_type = "out", ListType(BaseType(BaseTy.Tensor), count)
        else:
            # Defensive only: not reached once `outputs` is known to be non-empty.
            raise AssertionError(r"Unexpected PythonOutArgument size")

        return PythonOutArgument(
            name=out_name,
            type=out_type,
            default="None",
            default_init=None,
            outputs=outputs,
        )
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
336*da0073e9SAndroid Build Coastguard Workerclass PythonSignature:
337*da0073e9SAndroid Build Coastguard Worker    # Base operator name, without inplace/outplace suffix.
338*da0073e9SAndroid Build Coastguard Worker    name: str
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker    # Positional arguments.
341*da0073e9SAndroid Build Coastguard Worker    # TODO: create a dedicated SelfArgument type for 'self'?
342*da0073e9SAndroid Build Coastguard Worker    input_args: tuple[PythonArgument, ...]
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
345*da0073e9SAndroid Build Coastguard Worker    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
346*da0073e9SAndroid Build Coastguard Worker    input_kwargs: tuple[PythonArgument, ...]
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker    output_args: PythonOutArgument | None
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    # Return types, which are only used by pyi
351*da0073e9SAndroid Build Coastguard Worker    returns: PythonReturns
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker    # These are scattered kwargs arguments belonging to TensorOptions.
354*da0073e9SAndroid Build Coastguard Worker    # When binding to C++, they are packed into a TensorOptions object 'options'.
355*da0073e9SAndroid Build Coastguard Worker    # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
356*da0073e9SAndroid Build Coastguard Worker    # for out variant), in which case they will be used as scattered fields without
357*da0073e9SAndroid Build Coastguard Worker    # being packed into 'options'.
358*da0073e9SAndroid Build Coastguard Worker    # TODO: maybe create a PythonTensorOptionsArgument?
359*da0073e9SAndroid Build Coastguard Worker    tensor_options_args: tuple[PythonArgument, ...]
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    # method or function signature?
362*da0073e9SAndroid Build Coastguard Worker    method: bool
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker    @property
365*da0073e9SAndroid Build Coastguard Worker    def deprecated(self) -> bool:
366*da0073e9SAndroid Build Coastguard Worker        return False
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker    def arguments(
369*da0073e9SAndroid Build Coastguard Worker        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
370*da0073e9SAndroid Build Coastguard Worker    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
371*da0073e9SAndroid Build Coastguard Worker        result: list[PythonArgument | PythonOutArgument] = []
372*da0073e9SAndroid Build Coastguard Worker        result.extend(self.input_args)
373*da0073e9SAndroid Build Coastguard Worker        result.extend(self.input_kwargs)
374*da0073e9SAndroid Build Coastguard Worker        if self.output_args is not None and not skip_outputs:
375*da0073e9SAndroid Build Coastguard Worker            result.append(self.output_args)
376*da0073e9SAndroid Build Coastguard Worker        if not skip_tensor_options:
377*da0073e9SAndroid Build Coastguard Worker            result.extend(self.tensor_options_args)
378*da0073e9SAndroid Build Coastguard Worker        return tuple(result)
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker    def arguments_count(self) -> int:
381*da0073e9SAndroid Build Coastguard Worker        return len(self.arguments())
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker    def output_idx(self) -> int:
384*da0073e9SAndroid Build Coastguard Worker        return len(self.input_args) + len(self.input_kwargs)
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker    # [old codegen] Compute the Python function signature for argument parsing,
387*da0073e9SAndroid Build Coastguard Worker    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
388*da0073e9SAndroid Build Coastguard Worker    # this is NOT the same type signature as specified by PEP 484
389*da0073e9SAndroid Build Coastguard Worker    # as understood by mypy; our format was independently developed
390*da0073e9SAndroid Build Coastguard Worker    # and has some quirks to make it more suitable specifically
391*da0073e9SAndroid Build Coastguard Worker    # for error parsing.
392*da0073e9SAndroid Build Coastguard Worker    #
393*da0073e9SAndroid Build Coastguard Worker    # For a translation to mypy-valid type signatures, see
394*da0073e9SAndroid Build Coastguard Worker    # signature_str_pyi().
395*da0073e9SAndroid Build Coastguard Worker    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
396*da0073e9SAndroid Build Coastguard Worker        args = self.arguments(skip_outputs=skip_outputs)
397*da0073e9SAndroid Build Coastguard Worker        schema_formals: list[str] = [
398*da0073e9SAndroid Build Coastguard Worker            a.argument_str(method=self.method, symint=symint) for a in args
399*da0073e9SAndroid Build Coastguard Worker        ]
400*da0073e9SAndroid Build Coastguard Worker        positional_argc = len(self.input_args)
401*da0073e9SAndroid Build Coastguard Worker        if len(schema_formals) > positional_argc:
402*da0073e9SAndroid Build Coastguard Worker            schema_formals.insert(positional_argc, "*")
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker        return f'{self.name}({", ".join(schema_formals)})'
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
407*da0073e9SAndroid Build Coastguard Worker        args = self.arguments(skip_outputs=skip_outputs)
408*da0073e9SAndroid Build Coastguard Worker        schema_formals: list[str] = [
409*da0073e9SAndroid Build Coastguard Worker            a.argument_str_pyi(method=self.method) for a in args
410*da0073e9SAndroid Build Coastguard Worker        ]
411*da0073e9SAndroid Build Coastguard Worker        positional_argc = len(self.input_args)
412*da0073e9SAndroid Build Coastguard Worker        if len(schema_formals) > positional_argc:
413*da0073e9SAndroid Build Coastguard Worker            schema_formals.insert(positional_argc, "*")
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker        # only pyi signatures include returns
416*da0073e9SAndroid Build Coastguard Worker        returns_str = returns_str_pyi(self)
417*da0073e9SAndroid Build Coastguard Worker        # pyi also includes self (with no typing/defaults) for methods
418*da0073e9SAndroid Build Coastguard Worker        if self.method:
419*da0073e9SAndroid Build Coastguard Worker            schema_formals.insert(0, "self")
420*da0073e9SAndroid Build Coastguard Worker        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
423*da0073e9SAndroid Build Coastguard Worker        # only pyi uses vararg signatures
424*da0073e9SAndroid Build Coastguard Worker        args = self.arguments(skip_outputs=skip_outputs)
425*da0073e9SAndroid Build Coastguard Worker        schema_formals: list[str] = [
426*da0073e9SAndroid Build Coastguard Worker            a.argument_str_pyi(method=self.method) for a in args
427*da0073e9SAndroid Build Coastguard Worker        ]
428*da0073e9SAndroid Build Coastguard Worker        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
429*da0073e9SAndroid Build Coastguard Worker        num_args = self.arguments_count()
430*da0073e9SAndroid Build Coastguard Worker        num_positionalargs = len(self.input_args)
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker        have_vararg_version = False
433*da0073e9SAndroid Build Coastguard Worker        if num_args > 0:
434*da0073e9SAndroid Build Coastguard Worker            vararg_type = args[0].type
435*da0073e9SAndroid Build Coastguard Worker            if (
436*da0073e9SAndroid Build Coastguard Worker                isinstance(vararg_type, ListType)
437*da0073e9SAndroid Build Coastguard Worker                and str(vararg_type.elem) in ["int", "SymInt"]
438*da0073e9SAndroid Build Coastguard Worker                and num_positionalargs == 1
439*da0073e9SAndroid Build Coastguard Worker            ):
440*da0073e9SAndroid Build Coastguard Worker                have_vararg_version = True
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        if not have_vararg_version:
443*da0073e9SAndroid Build Coastguard Worker            return None
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker        # Below are the major changes in vararg vs. regular pyi signatures
446*da0073e9SAndroid Build Coastguard Worker        # vararg signatures also omit the asterix
447*da0073e9SAndroid Build Coastguard Worker        assert isinstance(vararg_type, ListType)
448*da0073e9SAndroid Build Coastguard Worker        schema_formals[0] = (
449*da0073e9SAndroid Build Coastguard Worker            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
450*da0073e9SAndroid Build Coastguard Worker        )
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker        returns_str = returns_str_pyi(self)
453*da0073e9SAndroid Build Coastguard Worker        # pyi also includes self (with no typing/defaults) for methods
454*da0073e9SAndroid Build Coastguard Worker        if self.method:
455*da0073e9SAndroid Build Coastguard Worker            schema_formals.insert(0, "self")
456*da0073e9SAndroid Build Coastguard Worker        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker# The deprecated python signature involves some special logic, so create a
460*da0073e9SAndroid Build Coastguard Worker# dedicated data model to store these extra properties.
461*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
462*da0073e9SAndroid Build Coastguard Workerclass PythonSignatureDeprecated(PythonSignature):
463*da0073e9SAndroid Build Coastguard Worker    # Schema for the deprecated function
464*da0073e9SAndroid Build Coastguard Worker    deprecated_schema: FunctionSchema
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker    # The deprecated signature might miss some arguments that the corresponding
467*da0073e9SAndroid Build Coastguard Worker    # C++ signature expects. We need store the constant default values to pass in.
468*da0073e9SAndroid Build Coastguard Worker    # For example:
469*da0073e9SAndroid Build Coastguard Worker    #   [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
470*da0073e9SAndroid Build Coastguard Worker    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
471*da0073e9SAndroid Build Coastguard Worker    #   [func call]: self.addmm(mat1, mat2, beta, 1)
472*da0073e9SAndroid Build Coastguard Worker    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
473*da0073e9SAndroid Build Coastguard Worker    deprecated_args_exprs: tuple[str, ...]
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker    @property
476*da0073e9SAndroid Build Coastguard Worker    def deprecated(self) -> bool:
477*da0073e9SAndroid Build Coastguard Worker        return True
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
480*da0073e9SAndroid Build Coastguard Worker        return (
481*da0073e9SAndroid Build Coastguard Worker            PythonSignature.signature_str(
482*da0073e9SAndroid Build Coastguard Worker                self, skip_outputs=skip_outputs, symint=symint
483*da0073e9SAndroid Build Coastguard Worker            )
484*da0073e9SAndroid Build Coastguard Worker            + "|deprecated"
485*da0073e9SAndroid Build Coastguard Worker        )
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
488*da0073e9SAndroid Build Coastguard Worker        args = self.arguments(skip_outputs=skip_outputs)
489*da0073e9SAndroid Build Coastguard Worker        schema_formals: list[str] = [
490*da0073e9SAndroid Build Coastguard Worker            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
491*da0073e9SAndroid Build Coastguard Worker        ]
492*da0073e9SAndroid Build Coastguard Worker        positional_argc = len(self.input_args)
493*da0073e9SAndroid Build Coastguard Worker        if len(schema_formals) > positional_argc:
494*da0073e9SAndroid Build Coastguard Worker            schema_formals.insert(positional_argc, "*")
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        returns_str = returns_str_pyi(self)
497*da0073e9SAndroid Build Coastguard Worker        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
500*da0073e9SAndroid Build Coastguard Worker        # the codegen doesn't include vararg variants for deprecated signatures
501*da0073e9SAndroid Build Coastguard Worker        return None
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker# This struct is used to hold the PythonSignature and its corresponding
505*da0073e9SAndroid Build Coastguard Worker# NativeFunction BEFORE grouping base and out-variant functions.
506*da0073e9SAndroid Build Coastguard Worker# Why not store NativeFunction in PythonSignature or construct PythonSignature
507*da0073e9SAndroid Build Coastguard Worker# from NativeFunction? Because they are not 1-1 mapped.
508*da0073e9SAndroid Build Coastguard Worker# One native function could have both deprecated and non-deprecated python
509*da0073e9SAndroid Build Coastguard Worker# signatures - NativeFunction doesn't contain information to construct the
510*da0073e9SAndroid Build Coastguard Worker# deprecated python signature.
511*da0073e9SAndroid Build Coastguard Worker# One python signature is used to handle both the base and the out-variant
512*da0073e9SAndroid Build Coastguard Worker# function - see 'PythonSignatureGroup'.
513*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
514*da0073e9SAndroid Build Coastguard Workerclass PythonSignatureNativeFunctionPair:
515*da0073e9SAndroid Build Coastguard Worker    signature: PythonSignature
516*da0073e9SAndroid Build Coastguard Worker    function: NativeFunction
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker# We merge pairs of functions with signatures that are equivalent mod
520*da0073e9SAndroid Build Coastguard Worker# output arguments, and use a single entry in the python_arg_parser sig
521*da0073e9SAndroid Build Coastguard Worker# list for both (output arguments become optional).
522*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
523*da0073e9SAndroid Build Coastguard Workerclass PythonSignatureGroup:
524*da0073e9SAndroid Build Coastguard Worker    # The signature used for Python argument parsing. The outplace signature
525*da0073e9SAndroid Build Coastguard Worker    # is preferred if exists, because it can be used to parse inputs for both
526*da0073e9SAndroid Build Coastguard Worker    # the out-place variant and the base version (with output omitted).
527*da0073e9SAndroid Build Coastguard Worker    signature: PythonSignature
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker    # The regular ATen declaration (e.g. conv2d)
530*da0073e9SAndroid Build Coastguard Worker    base: NativeFunction
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker    # The out variant (e.g. conv2d_out)
533*da0073e9SAndroid Build Coastguard Worker    outplace: NativeFunction | None
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker    @classmethod
536*da0073e9SAndroid Build Coastguard Worker    def from_pairs(
537*da0073e9SAndroid Build Coastguard Worker        cls,
538*da0073e9SAndroid Build Coastguard Worker        functional: PythonSignatureNativeFunctionPair,
539*da0073e9SAndroid Build Coastguard Worker        out: PythonSignatureNativeFunctionPair | None,
540*da0073e9SAndroid Build Coastguard Worker    ) -> PythonSignatureGroup:
541*da0073e9SAndroid Build Coastguard Worker        if out is None:
542*da0073e9SAndroid Build Coastguard Worker            return PythonSignatureGroup(
543*da0073e9SAndroid Build Coastguard Worker                signature=functional.signature,
544*da0073e9SAndroid Build Coastguard Worker                base=functional.function,
545*da0073e9SAndroid Build Coastguard Worker                outplace=None,
546*da0073e9SAndroid Build Coastguard Worker            )
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker        # prefer the signature with optional out=... arguments because it's the
549*da0073e9SAndroid Build Coastguard Worker        # superset that can be used to parse input for both base and outplace.
550*da0073e9SAndroid Build Coastguard Worker        signature_kwargs = out.signature.__dict__.copy()
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        # Out overloads in C++ don't have TensorOptions arguments,
553*da0073e9SAndroid Build Coastguard Worker        # so take these from the functional variant
554*da0073e9SAndroid Build Coastguard Worker        signature_kwargs[
555*da0073e9SAndroid Build Coastguard Worker            "tensor_options_args"
556*da0073e9SAndroid Build Coastguard Worker        ] = functional.signature.tensor_options_args
557*da0073e9SAndroid Build Coastguard Worker
558*da0073e9SAndroid Build Coastguard Worker        return PythonSignatureGroup(
559*da0073e9SAndroid Build Coastguard Worker            signature=type(out.signature)(**signature_kwargs),
560*da0073e9SAndroid Build Coastguard Worker            base=functional.function,
561*da0073e9SAndroid Build Coastguard Worker            outplace=out.function,
562*da0073e9SAndroid Build Coastguard Worker        )
563*da0073e9SAndroid Build Coastguard Worker
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker# C++ function dispatch is wrapped in a lambda function. The lambda function
566*da0073e9SAndroid Build Coastguard Worker# has almost the same signature as the C++ function, only with some small
567*da0073e9SAndroid Build Coastguard Worker# variants - see details below.
568*da0073e9SAndroid Build Coastguard Worker# This data model is used to represent arguments of the lambda function
569*da0073e9SAndroid Build Coastguard Worker# signature.
570*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
571*da0073e9SAndroid Build Coastguard Workerclass DispatchLambdaArgument:
572*da0073e9SAndroid Build Coastguard Worker    name: str
573*da0073e9SAndroid Build Coastguard Worker    type_str: str
574*da0073e9SAndroid Build Coastguard Worker    is_out_arg: bool
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker# To pass PyObjects arguments to C++ function (via the lambda wrapper),
578*da0073e9SAndroid Build Coastguard Worker# we need first convert PyObjects into simple C++ objects. This work
579*da0073e9SAndroid Build Coastguard Worker# is done by PythonArgParser.
580*da0073e9SAndroid Build Coastguard Worker# This data model is used to represent the output of PythonArgParser.
581*da0073e9SAndroid Build Coastguard Worker# It has 1-1 mapping with PythonArgument in PythonSignature.
582*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
583*da0073e9SAndroid Build Coastguard Workerclass PythonArgParserOutputExpr:
584*da0073e9SAndroid Build Coastguard Worker    # argument name
585*da0073e9SAndroid Build Coastguard Worker    name: str
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker    # RHS expression to reference PythonArgParser output.
588*da0073e9SAndroid Build Coastguard Worker    expr: str
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker    # In some special cases we need create different expr, e.g.:
591*da0073e9SAndroid Build Coastguard Worker    # '_r.isNone(1)' instead of '_r.tensor(1)'.
592*da0073e9SAndroid Build Coastguard Worker    index: int
593*da0073e9SAndroid Build Coastguard Worker
594*da0073e9SAndroid Build Coastguard Worker    # The python argument it maps to.
595*da0073e9SAndroid Build Coastguard Worker    argument: PythonArgument
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker    @property
598*da0073e9SAndroid Build Coastguard Worker    def is_none_expr(self) -> str:
599*da0073e9SAndroid Build Coastguard Worker        return f"_r.isNone({self.index})"
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker
602*da0073e9SAndroid Build Coastguard Worker# To pass PythonArgParser output to the lambda wrapper, we need bind
603*da0073e9SAndroid Build Coastguard Worker# PythonArgParserOutputExpr to DispatchLambdaArgument.
604*da0073e9SAndroid Build Coastguard Worker# They are not always 1-1 mapped, e.g. scattered TensorOptions fields
605*da0073e9SAndroid Build Coastguard Worker# need be packed into a TensorOptions object, which is the argument
606*da0073e9SAndroid Build Coastguard Worker# that the lambda function wrapper takes.
607*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
608*da0073e9SAndroid Build Coastguard Workerclass DispatchLambdaArgumentExprs:
609*da0073e9SAndroid Build Coastguard Worker    # The exprs that provide the binding for lambda arguments, e.g.:
610*da0073e9SAndroid Build Coastguard Worker    #
611*da0073e9SAndroid Build Coastguard Worker    #   'self' -> '_r.tensor(0)'
612*da0073e9SAndroid Build Coastguard Worker    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
613*da0073e9SAndroid Build Coastguard Worker    #   'options' -> 'options'
614*da0073e9SAndroid Build Coastguard Worker    #
615*da0073e9SAndroid Build Coastguard Worker    # It has 1-1 mapping with DispatchLambdaArgument.
616*da0073e9SAndroid Build Coastguard Worker    exprs: Sequence[str]
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker    # Special local inits, which might introduce new variables that
619*da0073e9SAndroid Build Coastguard Worker    # the 'exprs' above reference, e.g.:
620*da0073e9SAndroid Build Coastguard Worker    #
621*da0073e9SAndroid Build Coastguard Worker    #   'auto out = _r.tensorlist_n<2>(2);'
622*da0073e9SAndroid Build Coastguard Worker    #
623*da0073e9SAndroid Build Coastguard Worker    inits: Sequence[str]
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
627*da0073e9SAndroid Build Coastguard Worker#
628*da0073e9SAndroid Build Coastguard Worker#                          Helper Functions
629*da0073e9SAndroid Build Coastguard Worker#
630*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Workerdef _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
634*da0073e9SAndroid Build Coastguard Worker    return CppSignatureGroup.from_native_function(f, method=method).signature
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Workerdef has_tensor_options(f: NativeFunction) -> bool:
638*da0073e9SAndroid Build Coastguard Worker    return f.func.arguments.tensor_options is not None
639*da0073e9SAndroid Build Coastguard Worker
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
642*da0073e9SAndroid Build Coastguard Worker#
643*da0073e9SAndroid Build Coastguard Worker#                          Python Signature
644*da0073e9SAndroid Build Coastguard Worker#
645*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
646*da0073e9SAndroid Build Coastguard Worker
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker# 'simple_type' was introduced by the old codegen, which is slightly
649*da0073e9SAndroid Build Coastguard Worker# different from the python schema type, e.g.: doesn't have '?' suffix
650*da0073e9SAndroid Build Coastguard Worker# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
651*da0073e9SAndroid Build Coastguard Workerdef argument_type_str(
652*da0073e9SAndroid Build Coastguard Worker    t: Type, *, simple_type: bool = False, symint: bool = True
653*da0073e9SAndroid Build Coastguard Worker) -> str:
654*da0073e9SAndroid Build Coastguard Worker    if isinstance(t, BaseType):
655*da0073e9SAndroid Build Coastguard Worker        if t.name == BaseTy.Tensor:
656*da0073e9SAndroid Build Coastguard Worker            return "Tensor"
657*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.int:
658*da0073e9SAndroid Build Coastguard Worker            return "int64_t"
659*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.float:
660*da0073e9SAndroid Build Coastguard Worker            return "double"
661*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.str:
662*da0073e9SAndroid Build Coastguard Worker            return "c10::string_view"
663*da0073e9SAndroid Build Coastguard Worker        elif t.name in [
664*da0073e9SAndroid Build Coastguard Worker            BaseTy.bool,
665*da0073e9SAndroid Build Coastguard Worker            BaseTy.QScheme,
666*da0073e9SAndroid Build Coastguard Worker            BaseTy.Scalar,
667*da0073e9SAndroid Build Coastguard Worker            BaseTy.ScalarType,
668*da0073e9SAndroid Build Coastguard Worker            BaseTy.Generator,
669*da0073e9SAndroid Build Coastguard Worker            BaseTy.Storage,
670*da0073e9SAndroid Build Coastguard Worker            BaseTy.Layout,
671*da0073e9SAndroid Build Coastguard Worker            BaseTy.Device,
672*da0073e9SAndroid Build Coastguard Worker            BaseTy.DeviceIndex,
673*da0073e9SAndroid Build Coastguard Worker            BaseTy.MemoryFormat,
674*da0073e9SAndroid Build Coastguard Worker            BaseTy.Dimname,
675*da0073e9SAndroid Build Coastguard Worker            BaseTy.Stream,
676*da0073e9SAndroid Build Coastguard Worker            BaseTy.ConstQuantizerPtr,
677*da0073e9SAndroid Build Coastguard Worker            BaseTy.SymInt,
678*da0073e9SAndroid Build Coastguard Worker        ]:
679*da0073e9SAndroid Build Coastguard Worker            # These python schema type names line up with their function schema names
680*da0073e9SAndroid Build Coastguard Worker            return t.name.name
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker    elif isinstance(t, OptionalType):
683*da0073e9SAndroid Build Coastguard Worker        if str(t.elem) == "Tensor":
684*da0073e9SAndroid Build Coastguard Worker            # Is it desired to keep '?' for simple_type with new style dispatcher?
685*da0073e9SAndroid Build Coastguard Worker            return "Tensor?"
686*da0073e9SAndroid Build Coastguard Worker        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
687*da0073e9SAndroid Build Coastguard Worker        return f"{elem}?"
688*da0073e9SAndroid Build Coastguard Worker    elif isinstance(t, ListType):
689*da0073e9SAndroid Build Coastguard Worker        size = t.size if not simple_type else None
690*da0073e9SAndroid Build Coastguard Worker        if str(t.elem) == "bool":
691*da0073e9SAndroid Build Coastguard Worker            assert t.size is not None
692*da0073e9SAndroid Build Coastguard Worker            return f"::std::array<bool,{t.size}>"
693*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "int":
694*da0073e9SAndroid Build Coastguard Worker            return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
695*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "SymInt":
696*da0073e9SAndroid Build Coastguard Worker            if symint:
697*da0073e9SAndroid Build Coastguard Worker                return (
698*da0073e9SAndroid Build Coastguard Worker                    f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
699*da0073e9SAndroid Build Coastguard Worker                )
700*da0073e9SAndroid Build Coastguard Worker            else:
701*da0073e9SAndroid Build Coastguard Worker                return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
702*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Tensor":
703*da0073e9SAndroid Build Coastguard Worker            return f"TensorList[{size}]" if size is not None else "TensorList"
704*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Scalar":
705*da0073e9SAndroid Build Coastguard Worker            return f"ScalarList[{size}]" if size is not None else "ScalarList"
706*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Tensor?":
707*da0073e9SAndroid Build Coastguard Worker            if simple_type:
708*da0073e9SAndroid Build Coastguard Worker                return "c10::List<::std::optional<Tensor>>"
709*da0073e9SAndroid Build Coastguard Worker            else:
710*da0073e9SAndroid Build Coastguard Worker                return "const c10::List<::std::optional<Tensor>> &"
711*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Dimname":
712*da0073e9SAndroid Build Coastguard Worker            return f"DimnameList[{size}]" if size is not None else "DimnameList"
713*da0073e9SAndroid Build Coastguard Worker        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
714*da0073e9SAndroid Build Coastguard Worker        return f"ArrayRef<{elem}>"
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError(f"unrecognized type {repr(t)}")
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Workerdef argument_type_size(t: Type) -> int | None:
720*da0073e9SAndroid Build Coastguard Worker    l = t.is_list_like()
721*da0073e9SAndroid Build Coastguard Worker    if l is not None and str(l.elem) != "bool":
722*da0073e9SAndroid Build Coastguard Worker        return l.size
723*da0073e9SAndroid Build Coastguard Worker    else:
724*da0073e9SAndroid Build Coastguard Worker        return None
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Workerdef argument(a: Argument) -> PythonArgument:
728*da0073e9SAndroid Build Coastguard Worker    return PythonArgument(
729*da0073e9SAndroid Build Coastguard Worker        name=a.name,
730*da0073e9SAndroid Build Coastguard Worker        type=a.type,
731*da0073e9SAndroid Build Coastguard Worker        # TODO: directly translate a.default to python default
732*da0073e9SAndroid Build Coastguard Worker        default=(
733*da0073e9SAndroid Build Coastguard Worker            str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
734*da0073e9SAndroid Build Coastguard Worker            if a.default is not None
735*da0073e9SAndroid Build Coastguard Worker            else None
736*da0073e9SAndroid Build Coastguard Worker        ),
737*da0073e9SAndroid Build Coastguard Worker        default_init=None,
738*da0073e9SAndroid Build Coastguard Worker    )
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
742*da0073e9SAndroid Build Coastguard Workerdef signature(
743*da0073e9SAndroid Build Coastguard Worker    f: NativeFunction, *, method: bool = False, pyi: bool = False
744*da0073e9SAndroid Build Coastguard Worker) -> PythonSignature:
745*da0073e9SAndroid Build Coastguard Worker    return signature_from_schema(
746*da0073e9SAndroid Build Coastguard Worker        f.func, category_override=f.category_override, method=method, pyi=pyi
747*da0073e9SAndroid Build Coastguard Worker    )
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Workerdef signature_from_schema(
751*da0073e9SAndroid Build Coastguard Worker    func: FunctionSchema,
752*da0073e9SAndroid Build Coastguard Worker    *,
753*da0073e9SAndroid Build Coastguard Worker    category_override: str | None,
754*da0073e9SAndroid Build Coastguard Worker    method: bool = False,
755*da0073e9SAndroid Build Coastguard Worker    pyi: bool = False,
756*da0073e9SAndroid Build Coastguard Worker) -> PythonSignature:
757*da0073e9SAndroid Build Coastguard Worker    args: list[Argument] = []
758*da0073e9SAndroid Build Coastguard Worker    args.extend(func.arguments.pre_self_positional)
759*da0073e9SAndroid Build Coastguard Worker    # Skip SelfArgument if this is method.
760*da0073e9SAndroid Build Coastguard Worker    if not method and func.arguments.self_arg is not None:
761*da0073e9SAndroid Build Coastguard Worker        args.append(func.arguments.self_arg.argument)
762*da0073e9SAndroid Build Coastguard Worker    args.extend(func.arguments.post_self_positional)
763*da0073e9SAndroid Build Coastguard Worker    args.extend(func.arguments.pre_tensor_options_kwarg_only)
764*da0073e9SAndroid Build Coastguard Worker    # Skip TensorOptionsArguments. Python side TensorOptions
765*da0073e9SAndroid Build Coastguard Worker    # arguments are created based on different rules - see below.
766*da0073e9SAndroid Build Coastguard Worker    args.extend(func.arguments.post_tensor_options_kwarg_only)
767*da0073e9SAndroid Build Coastguard Worker    args.extend(func.arguments.out)
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker    input_arg_set = {a.name for a in func.arguments.flat_positional}
770*da0073e9SAndroid Build Coastguard Worker    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
771*da0073e9SAndroid Build Coastguard Worker    out_arg_set = {a.name for a in func.arguments.out}
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
774*da0073e9SAndroid Build Coastguard Worker    input_kwargs = tuple(
775*da0073e9SAndroid Build Coastguard Worker        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
776*da0073e9SAndroid Build Coastguard Worker    )
777*da0073e9SAndroid Build Coastguard Worker    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
778*da0073e9SAndroid Build Coastguard Worker
779*da0073e9SAndroid Build Coastguard Worker    # Reintroduce the scattered fields of TensorOptions for Python.
780*da0073e9SAndroid Build Coastguard Worker    # Compared to the cpp counterpart, the python arguments have new property
781*da0073e9SAndroid Build Coastguard Worker    # (default_init) and a new argument 'requires_grad', which require some
782*da0073e9SAndroid Build Coastguard Worker    # special handlings.
783*da0073e9SAndroid Build Coastguard Worker    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
784*da0073e9SAndroid Build Coastguard Worker    # to the original versions in the yaml, this recreation is a potential
785*da0073e9SAndroid Build Coastguard Worker    # source of drift between eager and JIT. Pull this logic out to a shared place.
786*da0073e9SAndroid Build Coastguard Worker
787*da0073e9SAndroid Build Coastguard Worker    has_tensor_input_arg = any(
788*da0073e9SAndroid Build Coastguard Worker        a.type.is_tensor_like() for a in func.arguments.flat_non_out
789*da0073e9SAndroid Build Coastguard Worker    )
790*da0073e9SAndroid Build Coastguard Worker    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
791*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
792*da0073e9SAndroid Build Coastguard Worker            "argument named requires_grad is reserved, should not explicitly add it in the schema"
793*da0073e9SAndroid Build Coastguard Worker        )
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker    # [old codegen] this probably won't work if one of the returns is not a tensor,
796*da0073e9SAndroid Build Coastguard Worker    # but it will produce a compile-time error that is obvious.
797*da0073e9SAndroid Build Coastguard Worker    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker    name: str = cpp.name(func)
800*da0073e9SAndroid Build Coastguard Worker    is_factory_function = category_override == "factory" or (
801*da0073e9SAndroid Build Coastguard Worker        has_tensor_return and not has_tensor_input_arg
802*da0073e9SAndroid Build Coastguard Worker    )
803*da0073e9SAndroid Build Coastguard Worker    is_like_or_new_function = (
804*da0073e9SAndroid Build Coastguard Worker        category_override in ("new", "like")
805*da0073e9SAndroid Build Coastguard Worker        or name.startswith("new_")
806*da0073e9SAndroid Build Coastguard Worker        or name.endswith("_like")
807*da0073e9SAndroid Build Coastguard Worker    )
808*da0073e9SAndroid Build Coastguard Worker    is_dummy_function = category_override == "dummy"
809*da0073e9SAndroid Build Coastguard Worker
810*da0073e9SAndroid Build Coastguard Worker    tensor_options_args: list[PythonArgument] = []
811*da0073e9SAndroid Build Coastguard Worker    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker        def topt_default_init(name: str) -> str | None:
814*da0073e9SAndroid Build Coastguard Worker            topt_args = func.arguments.tensor_options
815*da0073e9SAndroid Build Coastguard Worker            if topt_args is None:
816*da0073e9SAndroid Build Coastguard Worker                return None
817*da0073e9SAndroid Build Coastguard Worker            a = getattr(topt_args, name)
818*da0073e9SAndroid Build Coastguard Worker            if a.default is None or a.default == "None":
819*da0073e9SAndroid Build Coastguard Worker                return None
820*da0073e9SAndroid Build Coastguard Worker            return cpp.default_expr(a.default, a.type, symint=False)
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker        tensor_options_args.append(
823*da0073e9SAndroid Build Coastguard Worker            PythonArgument(
824*da0073e9SAndroid Build Coastguard Worker                name="dtype",
825*da0073e9SAndroid Build Coastguard Worker                type=OptionalType(BaseType(BaseTy.ScalarType)),
826*da0073e9SAndroid Build Coastguard Worker                default="None",
827*da0073e9SAndroid Build Coastguard Worker                default_init=(
828*da0073e9SAndroid Build Coastguard Worker                    None if is_like_or_new_function else topt_default_init("dtype")
829*da0073e9SAndroid Build Coastguard Worker                ),
830*da0073e9SAndroid Build Coastguard Worker            )
831*da0073e9SAndroid Build Coastguard Worker        )
832*da0073e9SAndroid Build Coastguard Worker        tensor_options_args.append(
833*da0073e9SAndroid Build Coastguard Worker            PythonArgument(
834*da0073e9SAndroid Build Coastguard Worker                name="layout",
835*da0073e9SAndroid Build Coastguard Worker                type=OptionalType(BaseType(BaseTy.Layout)),
836*da0073e9SAndroid Build Coastguard Worker                default="None",
837*da0073e9SAndroid Build Coastguard Worker                default_init=(
838*da0073e9SAndroid Build Coastguard Worker                    None if is_like_or_new_function else topt_default_init("layout")
839*da0073e9SAndroid Build Coastguard Worker                ),
840*da0073e9SAndroid Build Coastguard Worker            )
841*da0073e9SAndroid Build Coastguard Worker        )
842*da0073e9SAndroid Build Coastguard Worker        tensor_options_args.append(
843*da0073e9SAndroid Build Coastguard Worker            PythonArgument(
844*da0073e9SAndroid Build Coastguard Worker                name="device",
845*da0073e9SAndroid Build Coastguard Worker                type=OptionalType(BaseType(BaseTy.Device)),
846*da0073e9SAndroid Build Coastguard Worker                default="None",
847*da0073e9SAndroid Build Coastguard Worker                default_init=(
848*da0073e9SAndroid Build Coastguard Worker                    None
849*da0073e9SAndroid Build Coastguard Worker                    if is_like_or_new_function
850*da0073e9SAndroid Build Coastguard Worker                    else (
851*da0073e9SAndroid Build Coastguard Worker                        topt_default_init("device")
852*da0073e9SAndroid Build Coastguard Worker                        or "torch::tensors::get_default_device()"
853*da0073e9SAndroid Build Coastguard Worker                    )
854*da0073e9SAndroid Build Coastguard Worker                ),
855*da0073e9SAndroid Build Coastguard Worker            )
856*da0073e9SAndroid Build Coastguard Worker        )
857*da0073e9SAndroid Build Coastguard Worker        tensor_options_args.append(
858*da0073e9SAndroid Build Coastguard Worker            PythonArgument(
859*da0073e9SAndroid Build Coastguard Worker                name="pin_memory",
860*da0073e9SAndroid Build Coastguard Worker                type=OptionalType(BaseType(BaseTy.bool)),
861*da0073e9SAndroid Build Coastguard Worker                default="False",
862*da0073e9SAndroid Build Coastguard Worker                default_init=None,
863*da0073e9SAndroid Build Coastguard Worker            )
864*da0073e9SAndroid Build Coastguard Worker        )
865*da0073e9SAndroid Build Coastguard Worker        tensor_options_args.append(
866*da0073e9SAndroid Build Coastguard Worker            PythonArgument(
867*da0073e9SAndroid Build Coastguard Worker                name="requires_grad",
868*da0073e9SAndroid Build Coastguard Worker                type=OptionalType(BaseType(BaseTy.bool)),
869*da0073e9SAndroid Build Coastguard Worker                default="False",
870*da0073e9SAndroid Build Coastguard Worker                default_init=None,
871*da0073e9SAndroid Build Coastguard Worker            )
872*da0073e9SAndroid Build Coastguard Worker        )
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker    returns = PythonReturns(returns=func.returns)
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker    return PythonSignature(
877*da0073e9SAndroid Build Coastguard Worker        name=str(func.name.name),
878*da0073e9SAndroid Build Coastguard Worker        input_args=input_args,
879*da0073e9SAndroid Build Coastguard Worker        input_kwargs=input_kwargs,
880*da0073e9SAndroid Build Coastguard Worker        output_args=PythonOutArgument.from_outputs(outputs),
881*da0073e9SAndroid Build Coastguard Worker        tensor_options_args=tuple(tensor_options_args),
882*da0073e9SAndroid Build Coastguard Worker        returns=returns,
883*da0073e9SAndroid Build Coastguard Worker        method=method,
884*da0073e9SAndroid Build Coastguard Worker    )
885*da0073e9SAndroid Build Coastguard Worker
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
888*da0073e9SAndroid Build Coastguard Worker#
889*da0073e9SAndroid Build Coastguard Worker#                          Python Interface
890*da0073e9SAndroid Build Coastguard Worker#
891*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    """Return the field names of the structseq produced for ``returns``.

    A structseq is only generated when there are at least two returns and
    every one of them is named; in that case the names are returned in order.
    Zero or one return, or returns that are all unnamed, yield ``[]``.

    Raises:
        ValueError: if some, but not all, of the returns are named.
    """
    if len(returns) < 2:
        return []

    field_names = [r.name for r in returns]
    named_count = sum(name is not None for name in field_names)
    if named_count == 0:
        return []

    if named_count != len(field_names):
        # When building on Windows, `PyStructSequence_UnnamedField` could not be
        # resolved by the linker, which broke the build with:
        #
        # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
        # PyStructSequence_UnnamedField
        #
        # Unnamed structseq fields are therefore unsupported: name either all
        # fields or none of them.
        raise ValueError("Unnamed field is not supported by codegen")

    return [str(name) for name in field_names]
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker
def argument_type_str_pyi(t: Type) -> str:
    """Return the annotation used in generated ``.pyi`` stubs for an argument.

    Args:
        t: the schema type of the argument. An outer ``OptionalType`` is
            unwrapped, and ``Optional[...]`` is re-applied around the result.

    Returns:
        The annotation string, e.g. ``"_int"`` or ``"Optional[_dtype]"``.

    Raises:
        RuntimeError: if ``t`` (or, for a ``BaseType``, its base type) has no
            known ``.pyi`` spelling.
    """
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        # A single if/elif chain, so that an unlisted base type reaches the
        # explicit error below instead of leaving `ret` unbound.
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name
        else:
            raise RuntimeError(f"unrecognized type {repr(t)}")

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            # Fixed-size int lists also accept a single int.
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to   Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = f"Optional[{ret}]"

    return ret
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker
def return_type_str_pyi(t: Type) -> str:
    """Return the annotation used in generated ``.pyi`` stubs for a return value.

    Arguments may accept unions of types, whereas return types should be
    concrete:

    * a ``Device`` is annotated as ``_device``;
    * a ``Dimname`` is annotated as ``Optional[str]``;
    * a list becomes a variable-length ``Tuple[...]`` of its element type.

    Everything else falls back to ``argument_type_str_pyi``.
    """
    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        if t.name == BaseTy.Dimname:
            # Previously this value was assigned but never returned, so the
            # argument spelling `Union[str, ellipsis, None]` leaked into return
            # annotations.
            return "Optional[str]"
        return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)
1003*da0073e9SAndroid Build Coastguard Worker
1004*da0073e9SAndroid Build Coastguard Worker
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    """Build the ``.pyi`` class stub for a signature that returns a structseq.

    Args:
        signature: the Python signature whose returns are described.

    Returns:
        ``(structseq_name, structseq_def)``, where ``structseq_def`` is the stub
        source text (newline-terminated). Returns ``None`` when
        ``structseq_fieldnames`` reports no field names, i.e. the signature
        does not return a structseq.

    Raises:
        ValueError: propagated from ``structseq_fieldnames`` when only some of
            the returns are named.
    """
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if not field_names:
        return None

    # These types are structseq objects which act like named NamedTuples, but
    # the constructor acts like the constructor of tuple. Using typing.NamedTuple
    # does not allow us to override __init__.
    seq_type = f"Tuple[{', '.join(python_returns)}]"
    structseq_def_lines = [
        f"class {structseq_name}({seq_type}):",
    ]
    for name, typ in zip(field_names, python_returns):
        structseq_def_lines.extend(
            [
                "    @property",
                f"    def {name}(self) -> {typ}: ...",
            ]
        )
    structseq_def_lines.extend(
        [
            f"    def __new__(cls, sequence: {seq_type}): ...",
            f"    n_fields: _int = {len(field_names)}",
            # Spelled exactly like the attribute CPython struct sequences expose.
            f"    n_sequence_fields: _int = {len(field_names)}",
            "    n_unnamed_fields: _int = 0",
            "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
            "",  # add an extra newline
        ]
    )
    structseq_def = "\n".join(structseq_def_lines)
    # Example:
    # structseq_def = (
    #     "class max(Tuple[Tensor, Tensor]):\n"
    #     "    @property\n"
    #     "    def values(self) -> Tensor: ...\n"
    #     "    @property\n"
    #     "    def indices(self) -> Tensor: ...\n"
    #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
    #     "    n_fields: _int = 2\n"
    #     "    n_sequence_fields: _int = 2\n"
    #     "    n_unnamed_fields: _int = 0\n"
    #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing\n"
    # )
    return structseq_name, structseq_def
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker
def returns_str_pyi(signature: PythonSignature) -> str:
    """Return the ``.pyi`` return annotation for ``signature``.

    * Named multi-value returns refer to the generated
      ``torch.return_types.<name>`` structseq class.
    * No returns map to ``"None"``.
    * A single return uses that return's annotation directly.
    * Several unnamed returns become a ``Tuple[...]``.
    """
    return_list = signature.returns.returns
    if structseq_fieldnames(return_list):
        return f"torch.return_types.{signature.name}"

    annotations = [return_type_str_pyi(r.type) for r in return_list]
    if not annotations:
        return "None"
    if len(annotations) == 1:
        return annotations[0]
    return f"Tuple[{', '.join(annotations)}]"
1063*da0073e9SAndroid Build Coastguard Worker
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1066*da0073e9SAndroid Build Coastguard Worker#
1067*da0073e9SAndroid Build Coastguard Worker#                        C++ Function Dispatch
1068*da0073e9SAndroid Build Coastguard Worker#
1069*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1070*da0073e9SAndroid Build Coastguard Worker# This section provides APIs to generate the code that does C++ function
1071*da0073e9SAndroid Build Coastguard Worker# dispatch. The C++ function call is wrapped by a lambda function.
1072*da0073e9SAndroid Build Coastguard Worker# For example:
1073*da0073e9SAndroid Build Coastguard Worker#
1074*da0073e9SAndroid Build Coastguard Worker#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
1075*da0073e9SAndroid Build Coastguard Worker#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
1076*da0073e9SAndroid Build Coastguard Worker#      pybind11::gil_scoped_release no_gil;
1077*da0073e9SAndroid Build Coastguard Worker#      return at::selu_(self);
1078*da0073e9SAndroid Build Coastguard Worker#    };
1079*da0073e9SAndroid Build Coastguard Worker#
1080*da0073e9SAndroid Build Coastguard Worker# The lambda function's signature follows the C++ signature in common
1081*da0073e9SAndroid Build Coastguard Worker# cases, e.g.:
1082*da0073e9SAndroid Build Coastguard Worker#
1083*da0073e9SAndroid Build Coastguard Worker#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
1084*da0073e9SAndroid Build Coastguard Worker#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
1085*da0073e9SAndroid Build Coastguard Worker#
# For out variants the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because the lambda is called with the PythonArgParser
# output '_r.tensor(3)', which is a stack-allocated object and therefore must
# be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
1090*da0073e9SAndroid Build Coastguard Worker#
1091*da0073e9SAndroid Build Coastguard Worker#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
1092*da0073e9SAndroid Build Coastguard Worker#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
1093*da0073e9SAndroid Build Coastguard Worker#
1094*da0073e9SAndroid Build Coastguard Worker# For multi-output case it can keep using reference type because the
1095*da0073e9SAndroid Build Coastguard Worker# PythonArgParser output has been unpacked to local variables, e.g.:
1096*da0073e9SAndroid Build Coastguard Worker#
1097*da0073e9SAndroid Build Coastguard Worker#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
1098*da0073e9SAndroid Build Coastguard Worker#   //     Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
1099*da0073e9SAndroid Build Coastguard Worker#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
1100*da0073e9SAndroid Build Coastguard Worker#
1101*da0073e9SAndroid Build Coastguard Worker# For deprecated python signature, it should follow deprecated python arg order.
1102*da0073e9SAndroid Build Coastguard Worker# TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker
def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    """Compute the parameter list of the C++ dispatch lambda for ``ps``.

    The parameters mirror the non-faithful C++ arguments of the schema and
    always include ``self``. A deprecated signature uses its own deprecated
    schema, so the parameters follow the deprecated argument order.

    Args:
        ps: the Python signature being bound.
        f: the native function supplying the schema and ``cpp_no_default_args``.
        symint: forwarded to ``cpp.arguments``.

    Returns:
        One ``DispatchLambdaArgument`` per C++ argument, in C++ argument order.
    """
    schema = (
        ps.deprecated_schema
        if isinstance(ps, PythonSignatureDeprecated)
        else f.func
    )

    # method=False: the dispatch lambda signature always includes 'self'.
    bindings = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_names: set[str] = {a.name for a in schema.arguments.out}
    # With more than one output, the PythonArgParser results have been unpacked
    # into local variables, so out arguments may keep their reference types.
    has_scattered_outputs = len(out_names) > 1

    lambda_args: list[DispatchLambdaArgument] = []
    for binding in bindings:
        is_out = binding.name in out_names
        type_str = binding.type
        if ps.method and binding.name == "self":
            # A method's 'self' can be taken as 'const Tensor &'; its
            # mutability can simply be ignored here.
            type_str = "const at::Tensor &"
        elif not (has_scattered_outputs and is_out):
            # Take mutable tensors by value to avoid dangling references to
            # temporaries; see the comments above and in
            # 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            if type_str == "at::Tensor &":
                type_str = "at::Tensor"
        lambda_args.append(
            DispatchLambdaArgument(
                name=binding.name,
                type_str=type_str,
                is_out_arg=is_out,
            )
        )

    return tuple(lambda_args)
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker
# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
#
# C++ return types (as spelled by the C++ signature codegen) that the Python
# binding codegen accepts. Grouped by kind; the order inside the set is irrelevant.
SUPPORTED_RETURN_TYPES = {
    # Single non-tensor values.
    "void",
    "void*",
    "bool",
    "int64_t",
    "double",
    "at::Scalar",
    "at::QScheme",
    "at::ScalarType",
    "at::IntArrayRef",
    "at::Stream",
    # A single tensor, or a list of tensors.
    "at::Tensor",
    "::std::vector<at::Tensor>",
    # Homogeneous tuples of 2 to 6 tensors.
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    # Tuples mixing tensors with scalars or tensor lists.
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    # Needed for flash attention forw/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
}
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Workerdef dispatch_lambda_return_str(f: NativeFunction) -> str:
1182*da0073e9SAndroid Build Coastguard Worker    # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
1183*da0073e9SAndroid Build Coastguard Worker    # because the dispatch lambdas take mutable arguments *by value*, not
1184*da0073e9SAndroid Build Coastguard Worker    # by reference. If you then return a reference to such an argument, you
1185*da0073e9SAndroid Build Coastguard Worker    # will now have a pointer to a dangling stack entry. Not good.
1186*da0073e9SAndroid Build Coastguard Worker    #
1187*da0073e9SAndroid Build Coastguard Worker    # You want:
1188*da0073e9SAndroid Build Coastguard Worker    #
1189*da0073e9SAndroid Build Coastguard Worker    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
1190*da0073e9SAndroid Build Coastguard Worker    #                                            ^^^^^^
1191*da0073e9SAndroid Build Coastguard Worker    #
1192*da0073e9SAndroid Build Coastguard Worker    # *not*
1193*da0073e9SAndroid Build Coastguard Worker    #
1194*da0073e9SAndroid Build Coastguard Worker    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
1195*da0073e9SAndroid Build Coastguard Worker    #                                            ^^^^^^^
1196*da0073e9SAndroid Build Coastguard Worker    #
1197*da0073e9SAndroid Build Coastguard Worker    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
1198*da0073e9SAndroid Build Coastguard Worker    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
1199*da0073e9SAndroid Build Coastguard Worker    # mutable reference to temporary.  Maybe we could assign it to a
1200*da0073e9SAndroid Build Coastguard Worker    # variable itself.)
1201*da0073e9SAndroid Build Coastguard Worker    returns_without_annotation = tuple(
1202*da0073e9SAndroid Build Coastguard Worker        Return(r.name, r.type, None) for r in f.func.returns
1203*da0073e9SAndroid Build Coastguard Worker    )
1204*da0073e9SAndroid Build Coastguard Worker    return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
1205*da0073e9SAndroid Build Coastguard Worker    if return_str not in SUPPORTED_RETURN_TYPES:
1206*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
1207*da0073e9SAndroid Build Coastguard Worker    return return_str
1208*da0073e9SAndroid Build Coastguard Worker
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Workerdef cpp_dispatch_target(f: NativeFunction) -> str:
1211*da0073e9SAndroid Build Coastguard Worker    symint = f.func.has_symint()
1212*da0073e9SAndroid Build Coastguard Worker    name = cpp.name(f.func, symint_overload=symint)
1213*da0073e9SAndroid Build Coastguard Worker    if Variant.method in f.variants:
1214*da0073e9SAndroid Build Coastguard Worker        return f"self.{name}"
1215*da0073e9SAndroid Build Coastguard Worker    if Variant.function in f.variants:
1216*da0073e9SAndroid Build Coastguard Worker        if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
1217*da0073e9SAndroid Build Coastguard Worker            namespace = "torch"
1218*da0073e9SAndroid Build Coastguard Worker        else:
1219*da0073e9SAndroid Build Coastguard Worker            namespace = "at"
1220*da0073e9SAndroid Build Coastguard Worker        return f"{namespace}::{name}"
1221*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Workerdef cpp_dispatch_exprs(
1225*da0073e9SAndroid Build Coastguard Worker    f: NativeFunction,
1226*da0073e9SAndroid Build Coastguard Worker    *,
1227*da0073e9SAndroid Build Coastguard Worker    python_signature: PythonSignature | None = None,
1228*da0073e9SAndroid Build Coastguard Worker) -> tuple[str, ...]:
1229*da0073e9SAndroid Build Coastguard Worker    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
1230*da0073e9SAndroid Build Coastguard Worker
1231*da0073e9SAndroid Build Coastguard Worker    exprs: tuple[str, ...] = ()
1232*da0073e9SAndroid Build Coastguard Worker    if not isinstance(python_signature, PythonSignatureDeprecated):
1233*da0073e9SAndroid Build Coastguard Worker        # By default the exprs are consistent with the C++ signature.
1234*da0073e9SAndroid Build Coastguard Worker        exprs = tuple(a.name for a in cpp_args)
1235*da0073e9SAndroid Build Coastguard Worker    else:
1236*da0073e9SAndroid Build Coastguard Worker        # For deprecated python signature we may need fill in some constants.
1237*da0073e9SAndroid Build Coastguard Worker        exprs = tuple(
1238*da0073e9SAndroid Build Coastguard Worker            filter(
1239*da0073e9SAndroid Build Coastguard Worker                lambda n: n != "out" or f.func.is_out_fn(),
1240*da0073e9SAndroid Build Coastguard Worker                python_signature.deprecated_args_exprs,
1241*da0073e9SAndroid Build Coastguard Worker            )
1242*da0073e9SAndroid Build Coastguard Worker        )
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker    if Variant.method in f.variants:
1245*da0073e9SAndroid Build Coastguard Worker        exprs = tuple(filter("self".__ne__, exprs))
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker    return exprs
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1251*da0073e9SAndroid Build Coastguard Worker#
1252*da0073e9SAndroid Build Coastguard Worker#                     Python / C++ Args Binding
1253*da0073e9SAndroid Build Coastguard Worker#
1254*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1255*da0073e9SAndroid Build Coastguard Worker
1256*da0073e9SAndroid Build Coastguard Worker
1257*da0073e9SAndroid Build Coastguard Worker# We explicitly enumerate the PythonArgParser unpacking methods for all
1258*da0073e9SAndroid Build Coastguard Worker# supported types. This might be more verbose than necessary, partially
1259*da0073e9SAndroid Build Coastguard Worker# because of the irregularity of unpacking method naming, partially
1260*da0073e9SAndroid Build Coastguard Worker# because we want to mimic the old codegen behavior - to reject
1261*da0073e9SAndroid Build Coastguard Worker# unexpected and/or unsupported cases which the old codegen rejects.
1262*da0073e9SAndroid Build Coastguard Worker# For certain cases it is intentionally more restrictive than necessary,
1263*da0073e9SAndroid Build Coastguard Worker# e.g.: it doesn't accepts doublelist with definite size.
1264*da0073e9SAndroid Build Coastguard Workerdef arg_parser_unpack_method(
1265*da0073e9SAndroid Build Coastguard Worker    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
1266*da0073e9SAndroid Build Coastguard Worker) -> str:
1267*da0073e9SAndroid Build Coastguard Worker    has_default_init = default_init is not None
1268*da0073e9SAndroid Build Coastguard Worker    if has_default_init and str(t) not in (
1269*da0073e9SAndroid Build Coastguard Worker        "ScalarType?",
1270*da0073e9SAndroid Build Coastguard Worker        "ScalarType",
1271*da0073e9SAndroid Build Coastguard Worker        "Device",
1272*da0073e9SAndroid Build Coastguard Worker        "Device?",
1273*da0073e9SAndroid Build Coastguard Worker        "Layout",
1274*da0073e9SAndroid Build Coastguard Worker        "Layout?",
1275*da0073e9SAndroid Build Coastguard Worker        "bool",
1276*da0073e9SAndroid Build Coastguard Worker        "bool?",
1277*da0073e9SAndroid Build Coastguard Worker    ):
1278*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(f"type '{t}' does not supported unpacking with default")
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker    if isinstance(t, BaseType):
1281*da0073e9SAndroid Build Coastguard Worker        if t.name in [
1282*da0073e9SAndroid Build Coastguard Worker            BaseTy.Tensor,
1283*da0073e9SAndroid Build Coastguard Worker            BaseTy.Stream,
1284*da0073e9SAndroid Build Coastguard Worker            BaseTy.Storage,
1285*da0073e9SAndroid Build Coastguard Worker            BaseTy.Scalar,
1286*da0073e9SAndroid Build Coastguard Worker            BaseTy.Dimname,
1287*da0073e9SAndroid Build Coastguard Worker        ]:
1288*da0073e9SAndroid Build Coastguard Worker            # These unpack methods line up with their schema names
1289*da0073e9SAndroid Build Coastguard Worker            return t.name.name.lower()
1290*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.ScalarType:
1291*da0073e9SAndroid Build Coastguard Worker            return "scalartypeWithDefault" if has_default_init else "scalartype"
1292*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.Device:
1293*da0073e9SAndroid Build Coastguard Worker            return "deviceWithDefault" if has_default_init else "device"
1294*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.DeviceIndex:
1295*da0073e9SAndroid Build Coastguard Worker            return "toInt64"
1296*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.int:
1297*da0073e9SAndroid Build Coastguard Worker            return "toInt64"
1298*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.SymInt:
1299*da0073e9SAndroid Build Coastguard Worker            return "toSymInt" if symint else "toInt64"
1300*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.bool:
1301*da0073e9SAndroid Build Coastguard Worker            return "toBoolWithDefault" if has_default_init else "toBool"
1302*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.float:
1303*da0073e9SAndroid Build Coastguard Worker            return "toDouble"
1304*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.str:
1305*da0073e9SAndroid Build Coastguard Worker            return "stringView"
1306*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.Layout:
1307*da0073e9SAndroid Build Coastguard Worker            return "layoutWithDefault" if has_default_init else "layout"
1308*da0073e9SAndroid Build Coastguard Worker        elif t.name == BaseTy.MemoryFormat:
1309*da0073e9SAndroid Build Coastguard Worker            return "memoryformat"
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker    elif isinstance(t, OptionalType):
1312*da0073e9SAndroid Build Coastguard Worker        if str(t.elem) == "Tensor":
1313*da0073e9SAndroid Build Coastguard Worker            return "optionalTensor"
1314*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Generator":
1315*da0073e9SAndroid Build Coastguard Worker            return "generator"
1316*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Dimname[]":
1317*da0073e9SAndroid Build Coastguard Worker            return "toDimnameListOptional"
1318*da0073e9SAndroid Build Coastguard Worker        elif not has_default_init and default in (
1319*da0073e9SAndroid Build Coastguard Worker            None,
1320*da0073e9SAndroid Build Coastguard Worker            "None",
1321*da0073e9SAndroid Build Coastguard Worker            "::std::nullopt",
1322*da0073e9SAndroid Build Coastguard Worker            "std::nullopt",
1323*da0073e9SAndroid Build Coastguard Worker        ):
1324*da0073e9SAndroid Build Coastguard Worker            # If default is None: append 'Optional' to elem's unpacking method
1325*da0073e9SAndroid Build Coastguard Worker            return (
1326*da0073e9SAndroid Build Coastguard Worker                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
1327*da0073e9SAndroid Build Coastguard Worker            )
1328*da0073e9SAndroid Build Coastguard Worker        else:
1329*da0073e9SAndroid Build Coastguard Worker            # Otherwise, load as underlying type with default
1330*da0073e9SAndroid Build Coastguard Worker            return arg_parser_unpack_method(
1331*da0073e9SAndroid Build Coastguard Worker                t.elem, default, default_init, symint=symint
1332*da0073e9SAndroid Build Coastguard Worker            )
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker    elif isinstance(t, ListType):
1335*da0073e9SAndroid Build Coastguard Worker        if str(t.elem) == "Tensor":
1336*da0073e9SAndroid Build Coastguard Worker            # accept and use definite size
1337*da0073e9SAndroid Build Coastguard Worker            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
1338*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Tensor?":
1339*da0073e9SAndroid Build Coastguard Worker            return "list_of_optional_tensors"
1340*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Dimname":
1341*da0073e9SAndroid Build Coastguard Worker            # accept definite size
1342*da0073e9SAndroid Build Coastguard Worker            return "dimnamelist"
1343*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "int":
1344*da0073e9SAndroid Build Coastguard Worker            # accept definite size
1345*da0073e9SAndroid Build Coastguard Worker            return "intlist"
1346*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "float":
1347*da0073e9SAndroid Build Coastguard Worker            return "doublelist"
1348*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "SymInt":
1349*da0073e9SAndroid Build Coastguard Worker            # accept definite size
1350*da0073e9SAndroid Build Coastguard Worker            return "symintlist" if symint else "intlist"
1351*da0073e9SAndroid Build Coastguard Worker        elif str(t.elem) == "Scalar":
1352*da0073e9SAndroid Build Coastguard Worker            return "scalarlist"
1353*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
1354*da0073e9SAndroid Build Coastguard Worker
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker# Return RHS expression for python argument using PythonArgParser output.
1357*da0073e9SAndroid Build Coastguard Worker# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
1358*da0073e9SAndroid Build Coastguard Workerdef arg_parser_output_expr(
1359*da0073e9SAndroid Build Coastguard Worker    arg_index: int, a: PythonArgument, *, symint: bool = True
1360*da0073e9SAndroid Build Coastguard Worker) -> PythonArgParserOutputExpr:
1361*da0073e9SAndroid Build Coastguard Worker    has_default = a.default_init is not None
1362*da0073e9SAndroid Build Coastguard Worker    unpack_method = arg_parser_unpack_method(
1363*da0073e9SAndroid Build Coastguard Worker        t=a.type, default=a.default, default_init=a.default_init, symint=symint
1364*da0073e9SAndroid Build Coastguard Worker    )
1365*da0073e9SAndroid Build Coastguard Worker    default = f", {a.default_init}" if has_default else ""
1366*da0073e9SAndroid Build Coastguard Worker    expr = f"_r.{unpack_method}({arg_index}{default})"
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker    return PythonArgParserOutputExpr(
1369*da0073e9SAndroid Build Coastguard Worker        name=a.name,
1370*da0073e9SAndroid Build Coastguard Worker        expr=expr,
1371*da0073e9SAndroid Build Coastguard Worker        index=arg_index,
1372*da0073e9SAndroid Build Coastguard Worker        argument=a,
1373*da0073e9SAndroid Build Coastguard Worker    )
1374*da0073e9SAndroid Build Coastguard Worker
1375*da0073e9SAndroid Build Coastguard Worker
1376*da0073e9SAndroid Build Coastguard Worker# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
1377*da0073e9SAndroid Build Coastguard Workerdef arg_parser_output_exprs(
1378*da0073e9SAndroid Build Coastguard Worker    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
1379*da0073e9SAndroid Build Coastguard Worker) -> dict[str, PythonArgParserOutputExpr]:
1380*da0073e9SAndroid Build Coastguard Worker    return {
1381*da0073e9SAndroid Build Coastguard Worker        e.name: e
1382*da0073e9SAndroid Build Coastguard Worker        for i, a in enumerate(ps.arguments())
1383*da0073e9SAndroid Build Coastguard Worker        for e in (arg_parser_output_expr(i, a, symint=symint),)
1384*da0073e9SAndroid Build Coastguard Worker    }
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker# argument name to type for scattered tensor options fields
1388*da0073e9SAndroid Build Coastguard WorkerTENSOR_OPTIONS_FIELDS = {
1389*da0073e9SAndroid Build Coastguard Worker    "dtype": "ScalarType?",
1390*da0073e9SAndroid Build Coastguard Worker    "device": "Device?",
1391*da0073e9SAndroid Build Coastguard Worker    "layout": "Layout?",
1392*da0073e9SAndroid Build Coastguard Worker    "pin_memory": "bool?",
1393*da0073e9SAndroid Build Coastguard Worker    "requires_grad": "bool?",
1394*da0073e9SAndroid Build Coastguard Worker}
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
1398*da0073e9SAndroid Build Coastguard Workerdef dispatch_lambda_exprs(
1399*da0073e9SAndroid Build Coastguard Worker    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
1400*da0073e9SAndroid Build Coastguard Worker) -> DispatchLambdaArgumentExprs:
1401*da0073e9SAndroid Build Coastguard Worker    # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
1402*da0073e9SAndroid Build Coastguard Worker    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
1403*da0073e9SAndroid Build Coastguard Worker    # outputs.
1404*da0073e9SAndroid Build Coastguard Worker    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
1405*da0073e9SAndroid Build Coastguard Worker    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
1406*da0073e9SAndroid Build Coastguard Worker    inits: list[str] = []
1407*da0073e9SAndroid Build Coastguard Worker    lambda_args_exprs: dict[str, str] = {}
1408*da0073e9SAndroid Build Coastguard Worker
1409*da0073e9SAndroid Build Coastguard Worker    has_toptions = has_tensor_options(f)
1410*da0073e9SAndroid Build Coastguard Worker
1411*da0073e9SAndroid Build Coastguard Worker    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
1412*da0073e9SAndroid Build Coastguard Worker    for a in ps.arguments(skip_tensor_options=True):
1413*da0073e9SAndroid Build Coastguard Worker        name = a.name
1414*da0073e9SAndroid Build Coastguard Worker        arg_parser_expr = arg_parser_outputs[a.name].expr
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker        if has_toptions and name == "self":
1417*da0073e9SAndroid Build Coastguard Worker            # TODO: why this needs to be special case?
1418*da0073e9SAndroid Build Coastguard Worker            inits.extend(
1419*da0073e9SAndroid Build Coastguard Worker                [
1420*da0073e9SAndroid Build Coastguard Worker                    f"auto self = {arg_parser_expr};",
1421*da0073e9SAndroid Build Coastguard Worker                ]
1422*da0073e9SAndroid Build Coastguard Worker            )
1423*da0073e9SAndroid Build Coastguard Worker            lambda_args_exprs[name] = name
1424*da0073e9SAndroid Build Coastguard Worker        elif (
1425*da0073e9SAndroid Build Coastguard Worker            isinstance(a, PythonOutArgument)
1426*da0073e9SAndroid Build Coastguard Worker            and len(a.outputs) > 1
1427*da0073e9SAndroid Build Coastguard Worker            and f.func.is_out_fn()
1428*da0073e9SAndroid Build Coastguard Worker        ):
1429*da0073e9SAndroid Build Coastguard Worker            inits.extend(
1430*da0073e9SAndroid Build Coastguard Worker                [
1431*da0073e9SAndroid Build Coastguard Worker                    f"auto out = {arg_parser_expr};",
1432*da0073e9SAndroid Build Coastguard Worker                ]
1433*da0073e9SAndroid Build Coastguard Worker            )
1434*da0073e9SAndroid Build Coastguard Worker            for i, out_arg in enumerate(a.outputs):
1435*da0073e9SAndroid Build Coastguard Worker                lambda_args_exprs[out_arg.name] = f"out[{i}]"
1436*da0073e9SAndroid Build Coastguard Worker        elif str(a.type) == "Dimname[]?":
1437*da0073e9SAndroid Build Coastguard Worker            # [old codegen]
1438*da0073e9SAndroid Build Coastguard Worker            # TODO: make this part of something more general, or get rid of it.
1439*da0073e9SAndroid Build Coastguard Worker            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
1440*da0073e9SAndroid Build Coastguard Worker            # optional<vector<T>>, which cannot be implicitly converted to
1441*da0073e9SAndroid Build Coastguard Worker            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
1442*da0073e9SAndroid Build Coastguard Worker            inits.extend(
1443*da0073e9SAndroid Build Coastguard Worker                [
1444*da0073e9SAndroid Build Coastguard Worker                    f"auto __{name} = {arg_parser_expr};",
1445*da0073e9SAndroid Build Coastguard Worker                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
1446*da0073e9SAndroid Build Coastguard Worker                ]
1447*da0073e9SAndroid Build Coastguard Worker            )
1448*da0073e9SAndroid Build Coastguard Worker            lambda_args_exprs[name] = name
1449*da0073e9SAndroid Build Coastguard Worker        else:
1450*da0073e9SAndroid Build Coastguard Worker            # default case - directly using PythonArgParser output expr
1451*da0073e9SAndroid Build Coastguard Worker            lambda_args_exprs[name] = arg_parser_expr
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker    # method's self is passed directly to python binding, rather than parsed
1454*da0073e9SAndroid Build Coastguard Worker    if ps.method:
1455*da0073e9SAndroid Build Coastguard Worker        lambda_args_exprs["self"] = "self"
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    # 2. special packing/checking for TensorOptions.
1458*da0073e9SAndroid Build Coastguard Worker    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
1459*da0073e9SAndroid Build Coastguard Worker    if has_toptions:
1460*da0073e9SAndroid Build Coastguard Worker        if f.func.is_out_fn():
1461*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"{f.func}: tensor options with output arg")
1462*da0073e9SAndroid Build Coastguard Worker        for a in ps.tensor_options_args:
1463*da0073e9SAndroid Build Coastguard Worker            if a.name not in TENSOR_OPTIONS_FIELDS:
1464*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1465*da0073e9SAndroid Build Coastguard Worker                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
1466*da0073e9SAndroid Build Coastguard Worker                )
1467*da0073e9SAndroid Build Coastguard Worker            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
1468*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1469*da0073e9SAndroid Build Coastguard Worker                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
1470*da0073e9SAndroid Build Coastguard Worker                )
1471*da0073e9SAndroid Build Coastguard Worker        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
1472*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1473*da0073e9SAndroid Build Coastguard Worker                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
1474*da0073e9SAndroid Build Coastguard Worker            )
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker        inits.append(
1477*da0073e9SAndroid Build Coastguard Worker            f"""\
1478*da0073e9SAndroid Build Coastguard Workerconst auto options = TensorOptions()
1479*da0073e9SAndroid Build Coastguard Worker    .dtype({arg_parser_outputs['dtype'].expr})
1480*da0073e9SAndroid Build Coastguard Worker    .device({arg_parser_outputs['device'].expr})
1481*da0073e9SAndroid Build Coastguard Worker    .layout({arg_parser_outputs['layout'].expr})
1482*da0073e9SAndroid Build Coastguard Worker    .requires_grad({arg_parser_outputs['requires_grad'].expr})
1483*da0073e9SAndroid Build Coastguard Worker    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
1484*da0073e9SAndroid Build Coastguard Workertorch::utils::maybe_initialize_device(options);
1485*da0073e9SAndroid Build Coastguard Worker"""
1486*da0073e9SAndroid Build Coastguard Worker        )
1487*da0073e9SAndroid Build Coastguard Worker        lambda_args_exprs["options"] = "options"
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker    # 3. special case - access scattered TensorOptions fields without packing
1490*da0073e9SAndroid Build Coastguard Worker    # TODO: maybe move to the generator side as it's not related to binding.
1491*da0073e9SAndroid Build Coastguard Worker    if not has_toptions and tensor_options_args_names:
1492*da0073e9SAndroid Build Coastguard Worker        if "dtype" in tensor_options_args_names:
1493*da0073e9SAndroid Build Coastguard Worker            # we're an output-arg variant, check these args against output tensor
1494*da0073e9SAndroid Build Coastguard Worker            if not f.func.is_out_fn():
1495*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1496*da0073e9SAndroid Build Coastguard Worker                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
1497*da0073e9SAndroid Build Coastguard Worker                )
1498*da0073e9SAndroid Build Coastguard Worker            if not all(a in tensor_options_args_names for a in ("layout", "device")):
1499*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1500*da0073e9SAndroid Build Coastguard Worker                    f"{f.func}: incomplete tensor options for output check"
1501*da0073e9SAndroid Build Coastguard Worker                )
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker            inits.append(
1504*da0073e9SAndroid Build Coastguard Worker                f"""\
1505*da0073e9SAndroid Build Coastguard Workercheck_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
1506*da0073e9SAndroid Build Coastguard Worker                       {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
1507*da0073e9SAndroid Build Coastguard Worker                       {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
1508*da0073e9SAndroid Build Coastguard Worker"""
1509*da0073e9SAndroid Build Coastguard Worker            )
1510*da0073e9SAndroid Build Coastguard Worker        # we'll set requires_grad on outgoing tensor
1511*da0073e9SAndroid Build Coastguard Worker        if "requires_grad" not in tensor_options_args_names:
1512*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1513*da0073e9SAndroid Build Coastguard Worker                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
1514*da0073e9SAndroid Build Coastguard Worker            )
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker    return DispatchLambdaArgumentExprs(
1517*da0073e9SAndroid Build Coastguard Worker        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
1518*da0073e9SAndroid Build Coastguard Worker        inits=inits,
1519*da0073e9SAndroid Build Coastguard Worker    )
1520