xref: /aosp_15_r20/external/pytorch/tools/pyi/gen_pyi.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport argparse
4*da0073e9SAndroid Build Coastguard Workerimport collections
5*da0073e9SAndroid Build Coastguard Workerimport importlib
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Workerfrom pprint import pformat
8*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence
9*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import Mock, patch
10*da0073e9SAndroid Build Coastguard Workerfrom warnings import warn
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerfrom tools.autograd.gen_python_functions import (
13*da0073e9SAndroid Build Coastguard Worker    group_overloads,
14*da0073e9SAndroid Build Coastguard Worker    load_signatures,
15*da0073e9SAndroid Build Coastguard Worker    should_generate_py_binding,
16*da0073e9SAndroid Build Coastguard Worker)
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.python import (
19*da0073e9SAndroid Build Coastguard Worker    PythonSignatureGroup,
20*da0073e9SAndroid Build Coastguard Worker    PythonSignatureNativeFunctionPair,
21*da0073e9SAndroid Build Coastguard Worker    returns_structseq_pyi,
22*da0073e9SAndroid Build Coastguard Worker)
23*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import parse_native_yaml, parse_tags_yaml
24*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
25*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import FileManager
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker"""
29*da0073e9SAndroid Build Coastguard WorkerThis module implements generation of type stubs for PyTorch,
30*da0073e9SAndroid Build Coastguard Workerenabling use of autocomplete in IDEs like PyCharm, which otherwise
31*da0073e9SAndroid Build Coastguard Workerdon't understand C extension modules.
32*da0073e9SAndroid Build Coastguard Worker
At the moment, this module only handles type stubs for torch and
torch.Tensor.  It should eventually be expanded to cover all functions
which are autogenerated.
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard WorkerHere's our general strategy:
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker- We start off with a hand-written __init__.pyi.in file.  This
40*da0073e9SAndroid Build Coastguard Worker  file contains type definitions for everything we cannot automatically
41*da0073e9SAndroid Build Coastguard Worker  generate, including pure Python definitions directly in __init__.py
42*da0073e9SAndroid Build Coastguard Worker  (the latter case should be pretty rare).
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker- We go through automatically bound functions based on the
45*da0073e9SAndroid Build Coastguard Worker  type information recorded in native_functions.yaml and
46*da0073e9SAndroid Build Coastguard Worker  generate type hints for them (generate_type_hints)
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard WorkerThere are a number of type hints which we've special-cased;
49*da0073e9SAndroid Build Coastguard Workerread gen_pyi for the gory details.
50*da0073e9SAndroid Build Coastguard Worker"""
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """Select and group the bindings that become ``torch.*`` functions or Tensor methods.

    A pair is kept when all three of these hold:

    * ``should_generate_py_binding`` accepts its native function;
    * the native function has no ``python_module``.  This presumably excludes
      ops bound into a submodule rather than ``torch`` / ``Tensor``; confirm
      against the native_functions.yaml schema.
    * the native function's ``variants`` contain the requested variant.

    Args:
        python_funcs: candidate (python signature, native function) pairs.
        method: when True, select ``Variant.method`` entries (Tensor methods);
            otherwise select ``Variant.function`` entries (``torch`` functions).

    Returns:
        The kept pairs, grouped by name via ``group_overloads``.
    """
    wanted_variant = Variant.method if method else Variant.function

    def is_bound(pair: PythonSignatureNativeFunctionPair) -> bool:
        native = pair.function
        return (
            should_generate_py_binding(native)
            and not native.python_module
            and wanted_variant in native.variants
        )

    selected = [pair for pair in python_funcs if is_bound(pair)]
    return group_overloads(selected)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier on the human eye.

# Reusable stub parameter-list fragments. FACTORY_PARAMS bundles the
# dtype / device / requires_grad / pin_memory keyword parameters with defaults.
DEVICE_PARAM = "device: Optional[DeviceLikeType] = None"
FACTORY_PARAMS = ", ".join(
    [
        "dtype: Optional[_dtype] = None",
        DEVICE_PARAM,
        "requires_grad: _bool = False",
        "pin_memory: _bool = False",
    ]
)
85*da0073e9SAndroid Build Coastguard Worker
# NOTE: specifying indices for Tensor.__getitem__
# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi:
#
# key: (
#     None
#     | slice
#     | ellipsis
#     | SupportsIndex
#     | _ArrayLikeInt_co
#     | tuple[None | slice | ellipsis | _ArrayLikeInt_co | SupportsIndex, ...]
# )
#
# where:
#
# _ArrayLikeInt_co = _DualArrayLike[
#     dtype[Union[bool_, integer[Any]]],
#     Union[bool, int],
# ]
#
# and
#
# _DualArrayLike = Union[
#     _SupportsArray[_DType],
#     _NestedSequence[_SupportsArray[_DType]],
#     _T,
#     _NestedSequence[_T],
# ]
#
# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple.
# We can substitute and simplify:
# _SupportsArray -> Tensor
# _ArrayLikeInt_co -> [bool | int | Tensor | NestedSequence[bool | int] | NestedSequence[Tensor]]
# which leaves us with key: T | tuple[T, ...], where T is:
# T = (
#     None | bool | int | slice | ellipsis | SupportsIndex
#     | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int]
# )

# NOTE: ellipsis is equal to type[Ellipsis] in stub files.
# SupportsIndex is deliberately left out of the leaf types: it is accepted at
# the top level of an index, but not inside a _NestedSequence.
_leaf_types = "Union[{}]".format(
    ", ".join(["None", "_bool", "_int", "slice", "ellipsis", "Tensor"])
)
_index = "Union[{}]".format(
    ", ".join(["SupportsIndex", _leaf_types, f"_NestedSequence[{_leaf_types}]"])
)
# A key is either a single index or a tuple of indices.
INDICES = "indices: Union[{}]".format(", ".join([_index, f"tuple[{_index}, ...]"]))
128*da0073e9SAndroid Build Coastguard Worker
# Names for which generate_type_hints emits no stub, unless the signature is
# deprecated.
blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp: each op appears as
    # its functional, in-place ("_") and out ("_out") spelling, in that order.
    *(
        f"{op}{suffix}"
        for op in ("add", "sub", "mul", "div", "true_divide", "floor_divide")
        for suffix in ("", "_", "_out")
    ),
    "to",
    "_to_copy",
    "copy_",
]
179*da0073e9SAndroid Build Coastguard Worker
# Operator dunder names without the surrounding "__". sig_for_ops strips the
# underscores from its argument and looks the name up in these tuples.
binary_ops = (
    # arithmetic
    (
        "add", "sub", "mul", "div", "pow", "lshift", "rshift",
        "mod", "truediv", "matmul", "floordiv",
    )
    # reverse arithmetic
    + ("radd", "rsub", "rmul", "rtruediv", "rfloordiv", "rpow")
    # logic
    + ("and", "or", "xor", "rand", "ror", "rxor")
    # inplace ops
    + (
        "iadd", "iand", "idiv", "ilshift", "imul", "ior",
        "irshift", "isub", "ixor", "ifloordiv", "imod",
    )
)
# eq/ne get a "type: ignore[override]" in sig_for_ops; ge/gt/lt/le do not.
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ("neg", "abs", "invert")
# Conversions to Python builtin types.
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker
def sig_for_ops(opname: str) -> list[str]:
    """Return the stub signature(s) for an operator special method.

    These dunders (``__add__``, ``__eq__``, ``__bool__``, ...) are bound by hand
    in Python rather than generated from native_functions.yaml, so their
    signatures are spelled out here.

    Args:
        opname: the full dunder name, e.g. ``"__add__"``.

    Returns:
        A one-element list holding the ``def ...: ...`` stub line.

    Raises:
        AssertionError: if ``opname`` is not wrapped in double underscores.
        Exception: if the unwrapped name is in none of the known op tuples.
    """
    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
    name = opname[2:-2]

    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]

    if name in comparison_ops:
        trailer = ""
        if name in symmetric_comparison_ops:
            # unsafe override https://github.com/python/mypy/issues/5704
            trailer = "  # type: ignore[override]"
        return [f"def {opname}(self, other: Any) -> Tensor: ...{trailer}"]

    if name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]

    if name in to_py_type_ops:
        # bool/float/complex map to themselves, nonzero to bool, and the
        # remaining conversions (long/index/int) to int.
        if name in {"bool", "float", "complex"}:
            builtin_name = name
        elif name == "nonzero":
            builtin_name = "bool"
        else:
            builtin_name = "int"
        # Qualified with "builtins.", presumably because Tensor methods such as
        # float()/int()/bool() shadow the bare names inside the stub class.
        return [f"def {opname}(self) -> builtins.{builtin_name}: ..."]

    raise Exception("unknown op", opname)  # noqa: TRY002
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker
def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
    """Return the .pyi stub line(s) for one grouped python signature.

    Lines are produced in this order:

    1. For a deprecated signature that has an out variant, a line without the
       outputs.  Deprecated signatures keep separate functional and out
       entries, unlike native ops, which fuse the two.
    2. The main line.  It includes the optional out argument when an out
       variant (``outplace``) exists and omits outputs otherwise.
    3. The vararg form, if ``signature_str_pyi_vararg`` returns one.

    Args:
        sig_group: the signature group to render.

    Returns:
        The stub lines.  The list is empty when the op name is in ``blocklist``
        and the signature is not deprecated.
    """
    signature = sig_group.signature

    # Some deprecated ops that are on the blocklist are still included in pyi.
    if signature.name in blocklist and not signature.deprecated:
        return []

    has_out_variant = sig_group.outplace is not None
    hints: list[str] = []

    if signature.deprecated and has_out_variant:
        hints.append(signature.signature_str_pyi(skip_outputs=True))

    hints.append(signature.signature_str_pyi(skip_outputs=not has_out_variant))

    vararg_hint = signature.signature_str_pyi_vararg(skip_outputs=not has_out_variant)
    if vararg_hint:
        hints.append(vararg_hint)

    return hints
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker
def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
    """Build three overload stubs for a max-pool style function, keyed on ``return_indices``.

    ``arg_list`` holds parameter strings that are joined with ", ".  Exactly one
    entry must be the placeholder ``"{return_indices}"``.  An entry may itself
    contain several comma-separated parameters.  The overloads produced are:

    1. ``return_indices: Literal[False] = False``, returning ``Tensor``.
    2. ``return_indices: Literal[True]`` passed positionally.  Every parameter
       up to and including it loses its default, and a ``/`` makes that run
       positional-only.  Returns ``Tuple[Tensor, Tensor]``.
    3. ``return_indices: Literal[True]`` passed by keyword.  A bare ``*`` is
       inserted before it.  Returns ``Tuple[Tensor, Tensor]``.

    Args:
        name: the function name emitted in each ``def``.
        arg_list: the parameter strings described above.  It is not mutated.

    Returns:
        ``{name: [overload_false, overload_true_positional, overload_true_keyword]}``.

    Raises:
        ValueError: if ``"{return_indices}"`` is not in ``arg_list``.
    """

    def _strip_defaults(params: str) -> str:
        # Remove " = <default>" from every parameter in a comma-separated
        # parameter string.  Only top-level commas separate parameters, so
        # annotations or defaults such as "Tuple[int, int] = (1, 1)" stay intact.
        pieces: list[str] = []
        depth = 0
        start = 0
        for i, ch in enumerate(params):
            if ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth -= 1
            elif ch == "," and depth == 0:
                pieces.append(params[start:i])
                start = i + 1
        pieces.append(params[start:])
        return ", ".join(piece.strip().split(" = ")[0] for piece in pieces)

    flag_pos = arg_list.index("{return_indices}")

    # If return_indices is a positional arg, nothing before it may have a
    # default (it has none itself), and "/" marks that run positional-only.
    arg_list_positional = (
        [_strip_defaults(arg) for arg in arg_list[: flag_pos + 1]]
        + ["/"]
        + arg_list[flag_pos + 1 :]
    )
    # Otherwise force return_indices to be a keyword-only argument.
    arg_list_keyword = arg_list.copy()
    arg_list_keyword.insert(flag_pos, "*")

    # Two formatting passes.  The first fills name/args, and the args still
    # carry the "{return_indices}" placeholder.  The second fills that
    # placeholder plus the double-braced return type.
    tmpl = "def {name}({args}) -> {{return_type}}: ..."
    return {
        name: [
            tmpl.format(name=name, args=", ".join(arg_list)).format(
                return_indices="return_indices: Literal[False] = False",
                return_type="Tensor",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
        ]
    }
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Workerdef gen_nn_functional(fm: FileManager) -> None:
324*da0073e9SAndroid Build Coastguard Worker    INPUT = "input: Tensor"
325*da0073e9SAndroid Build Coastguard Worker    KERNEL_SIZE = "kernel_size: Union[_int, _size]"
326*da0073e9SAndroid Build Coastguard Worker    STRIDE_PADDING = ", ".join(
327*da0073e9SAndroid Build Coastguard Worker        [
328*da0073e9SAndroid Build Coastguard Worker            "stride: Optional[Union[_int, _size]] = None",
329*da0073e9SAndroid Build Coastguard Worker            "padding: Union[_int, _size] = 0",
330*da0073e9SAndroid Build Coastguard Worker        ]
331*da0073e9SAndroid Build Coastguard Worker    )
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker    # TODO the list for `torch._C._nn` is nonexhaustive
334*da0073e9SAndroid Build Coastguard Worker    unsorted_c_nn_function_hints: dict[str, list[str]] = {}
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker    for d in (2, 3):
337*da0073e9SAndroid Build Coastguard Worker        unsorted_c_nn_function_hints.update(
338*da0073e9SAndroid Build Coastguard Worker            {
339*da0073e9SAndroid Build Coastguard Worker                f"avg_pool{d}d": [
340*da0073e9SAndroid Build Coastguard Worker                    f"def avg_pool{d}d({{}}) -> Tensor: ...".format(
341*da0073e9SAndroid Build Coastguard Worker                        ", ".join(
342*da0073e9SAndroid Build Coastguard Worker                            [
343*da0073e9SAndroid Build Coastguard Worker                                f"{INPUT}",
344*da0073e9SAndroid Build Coastguard Worker                                f"{KERNEL_SIZE}",
345*da0073e9SAndroid Build Coastguard Worker                                f"{STRIDE_PADDING}",
346*da0073e9SAndroid Build Coastguard Worker                                "ceil_mode: bool = False",
347*da0073e9SAndroid Build Coastguard Worker                                "count_include_pad: bool = True",
348*da0073e9SAndroid Build Coastguard Worker                                "divisor_override: Optional[int] = None",
349*da0073e9SAndroid Build Coastguard Worker                            ]
350*da0073e9SAndroid Build Coastguard Worker                        )
351*da0073e9SAndroid Build Coastguard Worker                    )
352*da0073e9SAndroid Build Coastguard Worker                ],
353*da0073e9SAndroid Build Coastguard Worker                f"fractional_max_pool{d}d": [
354*da0073e9SAndroid Build Coastguard Worker                    f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format(
355*da0073e9SAndroid Build Coastguard Worker                        ", ".join(
356*da0073e9SAndroid Build Coastguard Worker                            [
357*da0073e9SAndroid Build Coastguard Worker                                f"{INPUT}",
358*da0073e9SAndroid Build Coastguard Worker                                f"{KERNEL_SIZE}",
359*da0073e9SAndroid Build Coastguard Worker                                "output_size: Union[_int, _size]",
360*da0073e9SAndroid Build Coastguard Worker                                "_random_samples: Tensor",
361*da0073e9SAndroid Build Coastguard Worker                            ]
362*da0073e9SAndroid Build Coastguard Worker                        ),
363*da0073e9SAndroid Build Coastguard Worker                        "Tuple[Tensor, Tensor]",
364*da0073e9SAndroid Build Coastguard Worker                    )
365*da0073e9SAndroid Build Coastguard Worker                ],
366*da0073e9SAndroid Build Coastguard Worker                f"adaptive_max_pool{d}d": [
367*da0073e9SAndroid Build Coastguard Worker                    f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
368*da0073e9SAndroid Build Coastguard Worker                        ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]),
369*da0073e9SAndroid Build Coastguard Worker                        "Tuple[Tensor, Tensor]",
370*da0073e9SAndroid Build Coastguard Worker                    )
371*da0073e9SAndroid Build Coastguard Worker                ],
372*da0073e9SAndroid Build Coastguard Worker            }
373*da0073e9SAndroid Build Coastguard Worker        )
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker    unsorted_c_nn_function_hints.update(
376*da0073e9SAndroid Build Coastguard Worker        {
377*da0073e9SAndroid Build Coastguard Worker            "hardtanh": [
378*da0073e9SAndroid Build Coastguard Worker                "def hardtanh({}) -> Tensor: ...".format(
379*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
380*da0073e9SAndroid Build Coastguard Worker                        [
381*da0073e9SAndroid Build Coastguard Worker                            "input: Tensor",
382*da0073e9SAndroid Build Coastguard Worker                            "min_val: float = ...",
383*da0073e9SAndroid Build Coastguard Worker                            "max_val: float = ...",
384*da0073e9SAndroid Build Coastguard Worker                            "*",
385*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
386*da0073e9SAndroid Build Coastguard Worker                        ]
387*da0073e9SAndroid Build Coastguard Worker                    )
388*da0073e9SAndroid Build Coastguard Worker                )
389*da0073e9SAndroid Build Coastguard Worker            ],
390*da0073e9SAndroid Build Coastguard Worker            "hardtanh_": [
391*da0073e9SAndroid Build Coastguard Worker                "def hardtanh_({}) -> Tensor: ...".format(
392*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
393*da0073e9SAndroid Build Coastguard Worker                        [
394*da0073e9SAndroid Build Coastguard Worker                            "input: Tensor",
395*da0073e9SAndroid Build Coastguard Worker                            "min_val: float = ...",
396*da0073e9SAndroid Build Coastguard Worker                            "max_val: float = ...",
397*da0073e9SAndroid Build Coastguard Worker                        ]
398*da0073e9SAndroid Build Coastguard Worker                    )
399*da0073e9SAndroid Build Coastguard Worker                )
400*da0073e9SAndroid Build Coastguard Worker            ],
401*da0073e9SAndroid Build Coastguard Worker            "elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."],
402*da0073e9SAndroid Build Coastguard Worker            "leaky_relu": [
403*da0073e9SAndroid Build Coastguard Worker                "def leaky_relu({}) -> Tensor: ...".format(
404*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
405*da0073e9SAndroid Build Coastguard Worker                        [
406*da0073e9SAndroid Build Coastguard Worker                            "input: Tensor",
407*da0073e9SAndroid Build Coastguard Worker                            "negative_slope: float = ...",
408*da0073e9SAndroid Build Coastguard Worker                            "*",
409*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
410*da0073e9SAndroid Build Coastguard Worker                        ]
411*da0073e9SAndroid Build Coastguard Worker                    )
412*da0073e9SAndroid Build Coastguard Worker                )
413*da0073e9SAndroid Build Coastguard Worker            ],
414*da0073e9SAndroid Build Coastguard Worker            "leaky_relu_": [
415*da0073e9SAndroid Build Coastguard Worker                f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..."
416*da0073e9SAndroid Build Coastguard Worker            ],
417*da0073e9SAndroid Build Coastguard Worker            "log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."],
418*da0073e9SAndroid Build Coastguard Worker            "gelu": ["def gelu(input: Tensor, approximate: str = ...) -> Tensor: ..."],
419*da0073e9SAndroid Build Coastguard Worker            "softplus": [
420*da0073e9SAndroid Build Coastguard Worker                "def softplus({}) -> Tensor: ...".format(
421*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
422*da0073e9SAndroid Build Coastguard Worker                        ["input: Tensor", "beta: float = ...", "threshold: float = ..."]
423*da0073e9SAndroid Build Coastguard Worker                    )
424*da0073e9SAndroid Build Coastguard Worker                )
425*da0073e9SAndroid Build Coastguard Worker            ],
426*da0073e9SAndroid Build Coastguard Worker            "softshrink": [
427*da0073e9SAndroid Build Coastguard Worker                "def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..."
428*da0073e9SAndroid Build Coastguard Worker            ],
429*da0073e9SAndroid Build Coastguard Worker            "hardsigmoid": [
430*da0073e9SAndroid Build Coastguard Worker                f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..."
431*da0073e9SAndroid Build Coastguard Worker            ],
432*da0073e9SAndroid Build Coastguard Worker            "linear": [
433*da0073e9SAndroid Build Coastguard Worker                "def linear({}) -> Tensor: ...".format(
434*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
435*da0073e9SAndroid Build Coastguard Worker                        [
436*da0073e9SAndroid Build Coastguard Worker                            "input: Tensor",
437*da0073e9SAndroid Build Coastguard Worker                            "weight: Tensor",
438*da0073e9SAndroid Build Coastguard Worker                            "bias: Optional[Tensor] = None",
439*da0073e9SAndroid Build Coastguard Worker                        ]
440*da0073e9SAndroid Build Coastguard Worker                    )
441*da0073e9SAndroid Build Coastguard Worker                )
442*da0073e9SAndroid Build Coastguard Worker            ],
443*da0073e9SAndroid Build Coastguard Worker            "pad": [
444*da0073e9SAndroid Build Coastguard Worker                "def pad({}) -> Tensor: ...".format(
445*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
446*da0073e9SAndroid Build Coastguard Worker                        [
447*da0073e9SAndroid Build Coastguard Worker                            "input: Tensor",
448*da0073e9SAndroid Build Coastguard Worker                            "pad: Sequence[int]",
449*da0073e9SAndroid Build Coastguard Worker                            "mode: str = ...",
450*da0073e9SAndroid Build Coastguard Worker                            "value: Optional[float] = None",
451*da0073e9SAndroid Build Coastguard Worker                        ]
452*da0073e9SAndroid Build Coastguard Worker                    )
453*da0073e9SAndroid Build Coastguard Worker                )
454*da0073e9SAndroid Build Coastguard Worker            ],
455*da0073e9SAndroid Build Coastguard Worker            "one_hot": [
456*da0073e9SAndroid Build Coastguard Worker                "def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..."
457*da0073e9SAndroid Build Coastguard Worker            ],
458*da0073e9SAndroid Build Coastguard Worker            "scaled_dot_product_attention": [
459*da0073e9SAndroid Build Coastguard Worker                "def scaled_dot_product_attention({}) -> Tensor: ...".format(
460*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
461*da0073e9SAndroid Build Coastguard Worker                        [
462*da0073e9SAndroid Build Coastguard Worker                            "query: Tensor",
463*da0073e9SAndroid Build Coastguard Worker                            "key: Tensor",
464*da0073e9SAndroid Build Coastguard Worker                            "value: Tensor",
465*da0073e9SAndroid Build Coastguard Worker                            "attn_mask: Optional[Tensor] = None",
466*da0073e9SAndroid Build Coastguard Worker                            "dropout_p: float = 0.0",
467*da0073e9SAndroid Build Coastguard Worker                            "is_causal: bool = False",
468*da0073e9SAndroid Build Coastguard Worker                            "scale: Optional[float] = None",
469*da0073e9SAndroid Build Coastguard Worker                            "enable_gqa: bool = False",
470*da0073e9SAndroid Build Coastguard Worker                        ]
471*da0073e9SAndroid Build Coastguard Worker                    )
472*da0073e9SAndroid Build Coastguard Worker                )
473*da0073e9SAndroid Build Coastguard Worker            ],
474*da0073e9SAndroid Build Coastguard Worker        }
475*da0073e9SAndroid Build Coastguard Worker    )
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker    c_nn_function_hints: list[str] = []
478*da0073e9SAndroid Build Coastguard Worker    for _, hints in sorted(unsorted_c_nn_function_hints.items()):
479*da0073e9SAndroid Build Coastguard Worker        if len(hints) > 1:
480*da0073e9SAndroid Build Coastguard Worker            hints = ["@overload\n" + h for h in hints]
481*da0073e9SAndroid Build Coastguard Worker        c_nn_function_hints += hints
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
484*da0073e9SAndroid Build Coastguard Worker    # through an `_add_docstr` call
485*da0073e9SAndroid Build Coastguard Worker    torch_imports = [
486*da0073e9SAndroid Build Coastguard Worker        "conv1d",
487*da0073e9SAndroid Build Coastguard Worker        "conv2d",
488*da0073e9SAndroid Build Coastguard Worker        "conv3d",
489*da0073e9SAndroid Build Coastguard Worker        "conv_transpose1d",
490*da0073e9SAndroid Build Coastguard Worker        "conv_transpose2d",
491*da0073e9SAndroid Build Coastguard Worker        "conv_transpose3d",
492*da0073e9SAndroid Build Coastguard Worker        "conv_tbc",
493*da0073e9SAndroid Build Coastguard Worker        "avg_pool1d",
494*da0073e9SAndroid Build Coastguard Worker        "adaptive_avg_pool1d",
495*da0073e9SAndroid Build Coastguard Worker        "relu_",
496*da0073e9SAndroid Build Coastguard Worker        "selu_",
497*da0073e9SAndroid Build Coastguard Worker        "celu_",
498*da0073e9SAndroid Build Coastguard Worker        "prelu",
499*da0073e9SAndroid Build Coastguard Worker        "rrelu_",
500*da0073e9SAndroid Build Coastguard Worker        "hardshrink",
501*da0073e9SAndroid Build Coastguard Worker        "bilinear",
502*da0073e9SAndroid Build Coastguard Worker        "pixel_shuffle",
503*da0073e9SAndroid Build Coastguard Worker        "pixel_unshuffle",
504*da0073e9SAndroid Build Coastguard Worker        "channel_shuffle",
505*da0073e9SAndroid Build Coastguard Worker        "native_channel_shuffle",
506*da0073e9SAndroid Build Coastguard Worker        "pairwise_distance",
507*da0073e9SAndroid Build Coastguard Worker        "pdist",
508*da0073e9SAndroid Build Coastguard Worker        "cosine_similarity",
509*da0073e9SAndroid Build Coastguard Worker    ]
510*da0073e9SAndroid Build Coastguard Worker    imported_hints = [f"from torch import {_} as {_}" for _ in torch_imports]
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker    # Functions imported into `torch.nn.functional` from `torch._C._nn`
513*da0073e9SAndroid Build Coastguard Worker    c_nn_imports = [
514*da0073e9SAndroid Build Coastguard Worker        "avg_pool2d",
515*da0073e9SAndroid Build Coastguard Worker        "avg_pool3d",
516*da0073e9SAndroid Build Coastguard Worker        "hardtanh_",
517*da0073e9SAndroid Build Coastguard Worker        "elu_",
518*da0073e9SAndroid Build Coastguard Worker        "leaky_relu_",
519*da0073e9SAndroid Build Coastguard Worker        "gelu",
520*da0073e9SAndroid Build Coastguard Worker        "softplus",
521*da0073e9SAndroid Build Coastguard Worker        "softshrink",
522*da0073e9SAndroid Build Coastguard Worker        "linear",
523*da0073e9SAndroid Build Coastguard Worker        "pad",
524*da0073e9SAndroid Build Coastguard Worker        "one_hot",
525*da0073e9SAndroid Build Coastguard Worker        "scaled_dot_product_attention",
526*da0073e9SAndroid Build Coastguard Worker    ]
527*da0073e9SAndroid Build Coastguard Worker    imported_hints += [f"from torch._C._nn import {_} as {_}" for _ in c_nn_imports]
528*da0073e9SAndroid Build Coastguard Worker    # This is from `torch._C._nn` but renamed
529*da0073e9SAndroid Build Coastguard Worker    imported_hints.append(
530*da0073e9SAndroid Build Coastguard Worker        "from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid"
531*da0073e9SAndroid Build Coastguard Worker    )
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker    # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
534*da0073e9SAndroid Build Coastguard Worker    unsorted_dispatched_hints: dict[str, list[str]] = {}
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker    for d in (1, 2, 3):
537*da0073e9SAndroid Build Coastguard Worker        unsorted_dispatched_hints.update(
538*da0073e9SAndroid Build Coastguard Worker            **get_max_pool_dispatch(
539*da0073e9SAndroid Build Coastguard Worker                f"max_pool{d}d",
540*da0073e9SAndroid Build Coastguard Worker                [
541*da0073e9SAndroid Build Coastguard Worker                    f"{INPUT}",
542*da0073e9SAndroid Build Coastguard Worker                    f"{KERNEL_SIZE}",
543*da0073e9SAndroid Build Coastguard Worker                    f"{STRIDE_PADDING}",
544*da0073e9SAndroid Build Coastguard Worker                    "dilation: Union[_int, _size] = 1",
545*da0073e9SAndroid Build Coastguard Worker                    "ceil_mode: bool = False",
546*da0073e9SAndroid Build Coastguard Worker                    "{return_indices}",
547*da0073e9SAndroid Build Coastguard Worker                ],
548*da0073e9SAndroid Build Coastguard Worker            ),
549*da0073e9SAndroid Build Coastguard Worker            **get_max_pool_dispatch(
550*da0073e9SAndroid Build Coastguard Worker                f"fractional_max_pool{d}d",
551*da0073e9SAndroid Build Coastguard Worker                [
552*da0073e9SAndroid Build Coastguard Worker                    f"{INPUT}",
553*da0073e9SAndroid Build Coastguard Worker                    f"{KERNEL_SIZE}",
554*da0073e9SAndroid Build Coastguard Worker                    "output_size: Optional[Union[_int, _size]] = None",
555*da0073e9SAndroid Build Coastguard Worker                    "output_ratio: Optional[_ratio_any_t] = None",
556*da0073e9SAndroid Build Coastguard Worker                    "{return_indices}",
557*da0073e9SAndroid Build Coastguard Worker                    "_random_samples: Optional[Tensor] = None",
558*da0073e9SAndroid Build Coastguard Worker                ],
559*da0073e9SAndroid Build Coastguard Worker            ),
560*da0073e9SAndroid Build Coastguard Worker            **get_max_pool_dispatch(
561*da0073e9SAndroid Build Coastguard Worker                f"adaptive_max_pool{d}d",
562*da0073e9SAndroid Build Coastguard Worker                [f"{INPUT}", "output_size: Union[_int, _size]", "{return_indices}"],
563*da0073e9SAndroid Build Coastguard Worker            ),
564*da0073e9SAndroid Build Coastguard Worker        )
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker    # There's no fractional_max_pool1d
567*da0073e9SAndroid Build Coastguard Worker    del unsorted_dispatched_hints["fractional_max_pool1d"]
568*da0073e9SAndroid Build Coastguard Worker
569*da0073e9SAndroid Build Coastguard Worker    dispatched_hints: list[str] = []
570*da0073e9SAndroid Build Coastguard Worker    for _, hints in sorted(unsorted_dispatched_hints.items()):
571*da0073e9SAndroid Build Coastguard Worker        if len(hints) > 1:
572*da0073e9SAndroid Build Coastguard Worker            hints = ["@overload\n" + h for h in hints]
573*da0073e9SAndroid Build Coastguard Worker        dispatched_hints += hints
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Worker    fm.write_with_template(
576*da0073e9SAndroid Build Coastguard Worker        "torch/nn/functional.pyi",
577*da0073e9SAndroid Build Coastguard Worker        "torch/nn/functional.pyi.in",
578*da0073e9SAndroid Build Coastguard Worker        lambda: {
579*da0073e9SAndroid Build Coastguard Worker            "imported_hints": imported_hints,
580*da0073e9SAndroid Build Coastguard Worker            "dispatched_hints": dispatched_hints,
581*da0073e9SAndroid Build Coastguard Worker        },
582*da0073e9SAndroid Build Coastguard Worker    )
583*da0073e9SAndroid Build Coastguard Worker    fm.write_with_template(
584*da0073e9SAndroid Build Coastguard Worker        "torch/_C/_nn.pyi",
585*da0073e9SAndroid Build Coastguard Worker        "torch/_C/_nn.pyi.in",
586*da0073e9SAndroid Build Coastguard Worker        lambda: {
587*da0073e9SAndroid Build Coastguard Worker            "c_nn_function_hints": c_nn_function_hints,
588*da0073e9SAndroid Build Coastguard Worker        },
589*da0073e9SAndroid Build Coastguard Worker    )
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker"""
593*da0073e9SAndroid Build Coastguard WorkerWe gather the docstrings for torch with the following steps:
594*da0073e9SAndroid Build Coastguard Worker1. Mock torch and torch._C, which are the only dependencies of the docs files
595*da0073e9SAndroid Build Coastguard Worker2. Mock the _add_docstr function to save the docstrings
596*da0073e9SAndroid Build Coastguard Worker3. Import the docs files to trigger mocked _add_docstr and collect docstrings
597*da0073e9SAndroid Build Coastguard Worker"""
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker
def gather_docstrs() -> dict[str, str]:
    """Collect the docstrings that the docs modules register via ``_add_docstr``.

    ``_torch_docs`` and ``_tensor_docs`` are imported while ``torch`` and
    ``torch._C`` are replaced by mocks, so every ``torch._C._add_docstr`` call
    they make is intercepted and recorded instead of executed.

    Returns:
        A mapping from the mock name of each documented object (as reported by
        ``Mock._extract_mock_name``) to its whitespace-stripped docstring.
        If the docs modules cannot be found, a warning is emitted and whatever
        was collected before the failure (possibly nothing) is returned.
    """
    collected: dict[str, str] = {}

    def _record(target: Mock, text: str) -> None:
        # Stand-in for torch._C._add_docstr: remember the docstring keyed by
        # the mock's dotted name instead of attaching it to a real object.
        collected[target._extract_mock_name()] = text.strip()

    # (alias to register in sys.modules, module name to actually import)
    docs_modules = (
        ("torch._torch_docs", "_torch_docs"),
        ("torch._tensor_docs", "_tensor_docs"),
    )

    # Both patches undo themselves on exit, so the fake torch modules and the
    # extra "torch" search-path entry (relative to the current working
    # directory) never leak out of this function.
    with patch.dict(sys.modules):
        with patch.object(sys, "path", sys.path + ["torch"]):
            sys.modules["torch"] = Mock(name="torch")
            sys.modules["torch._C"] = Mock(_add_docstr=_record)
            try:
                for alias, module_name in docs_modules:
                    # Importing executes the module body, which fires _record
                    # for each _add_docstr call it contains.
                    sys.modules[alias] = importlib.import_module(module_name)
            except ModuleNotFoundError:
                # Best effort: stubs are still generated, just without docs.
                warn(
                    "Failed to import _torch_docs/_tensor_docs, skipping docstring in pyi files."
                )

    return collected
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker
def add_docstr_to_hint(docstr: str, hint: str) -> str:
    """Return ``hint`` with ``docstr`` attached as a raw-string docstring.

    Args:
        docstr: The docstring text to attach.
        hint: A single stub line.  A hint containing ``...`` is treated as a
            function/method stub and must end with ``...``.  Any other hint
            is treated as an attribute or property.

    Returns:
        For a function/method hint, the result is the signature with its
        trailing ``...`` removed, followed by these parts, each on its own
        line indented by four spaces:

        * the opening ``r`` + triple-quote;
        * the docstring's lines;
        * the closing triple-quote;
        * a ``...`` body.

        For an attribute/property hint, the result is the hint, then the
        raw docstring on the next line, then a trailing newline.

    Raises:
        AssertionError: if ``hint`` contains ``...`` without ending in it.
            This only happens when assertions are enabled.
    """
    if "..." not in hint:
        # Attribute or property: the docstring follows as a bare string.
        return f'{hint}\nr"""{docstr}"""\n'

    # Function or method: replace the "..." body with docstring + "...".
    assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
    signature = hint[:-3]  # strip the trailing "..." placeholder
    body = ['r"""', *docstr.split("\n"), '"""', "..."]
    return "\n    ".join([signature, *body])
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Workerdef gen_pyi(
636*da0073e9SAndroid Build Coastguard Worker    native_yaml_path: str,
637*da0073e9SAndroid Build Coastguard Worker    tags_yaml_path: str,
638*da0073e9SAndroid Build Coastguard Worker    deprecated_yaml_path: str,
639*da0073e9SAndroid Build Coastguard Worker    fm: FileManager,
640*da0073e9SAndroid Build Coastguard Worker) -> None:
641*da0073e9SAndroid Build Coastguard Worker    """gen_pyi()
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker    This function generates a pyi file for torch.
644*da0073e9SAndroid Build Coastguard Worker    """
645*da0073e9SAndroid Build Coastguard Worker
    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to update the other file.
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker    # Dictionary for NamedTuple definitions
654*da0073e9SAndroid Build Coastguard Worker    structseqs: dict[str, str] = {}
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker    # Generate type signatures for top-level functions
657*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker    unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker    for n, n1, n2 in [
662*da0073e9SAndroid Build Coastguard Worker        ("csr", "crow", "col"),
663*da0073e9SAndroid Build Coastguard Worker        ("csc", "ccol", "row"),
664*da0073e9SAndroid Build Coastguard Worker        ("bsr", "crow", "col"),
665*da0073e9SAndroid Build Coastguard Worker        ("bsc", "ccol", "row"),
666*da0073e9SAndroid Build Coastguard Worker    ]:
667*da0073e9SAndroid Build Coastguard Worker        unsorted_function_hints.update(
668*da0073e9SAndroid Build Coastguard Worker            {
669*da0073e9SAndroid Build Coastguard Worker                f"sparse_{n}_tensor": [
670*da0073e9SAndroid Build Coastguard Worker                    f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format(
671*da0073e9SAndroid Build Coastguard Worker                        ", ".join(
672*da0073e9SAndroid Build Coastguard Worker                            [
673*da0073e9SAndroid Build Coastguard Worker                                f"{n1}_indices: Union[Tensor, List]",
674*da0073e9SAndroid Build Coastguard Worker                                f"{n2}_indices: Union[Tensor, List]",
675*da0073e9SAndroid Build Coastguard Worker                                "values: Union[Tensor, List]",
676*da0073e9SAndroid Build Coastguard Worker                                "size: Optional[_size] = None",
677*da0073e9SAndroid Build Coastguard Worker                                "*",
678*da0073e9SAndroid Build Coastguard Worker                                "dtype: Optional[_dtype] = None",
679*da0073e9SAndroid Build Coastguard Worker                                "device: Optional[DeviceLikeType] = None",
680*da0073e9SAndroid Build Coastguard Worker                                "requires_grad: _bool = False",
681*da0073e9SAndroid Build Coastguard Worker                                "check_invariants: Optional[_bool] = None",
682*da0073e9SAndroid Build Coastguard Worker                            ]
683*da0073e9SAndroid Build Coastguard Worker                        ),
684*da0073e9SAndroid Build Coastguard Worker                    )
685*da0073e9SAndroid Build Coastguard Worker                ],
686*da0073e9SAndroid Build Coastguard Worker            }
687*da0073e9SAndroid Build Coastguard Worker        )
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker    unsorted_function_hints.update(
690*da0073e9SAndroid Build Coastguard Worker        {
691*da0073e9SAndroid Build Coastguard Worker            "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
692*da0073e9SAndroid Build Coastguard Worker            "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
693*da0073e9SAndroid Build Coastguard Worker            "asarray": [
694*da0073e9SAndroid Build Coastguard Worker                "def asarray({}) -> Tensor: ...".format(
695*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
696*da0073e9SAndroid Build Coastguard Worker                        [
697*da0073e9SAndroid Build Coastguard Worker                            "obj: Any",
698*da0073e9SAndroid Build Coastguard Worker                            "*",
699*da0073e9SAndroid Build Coastguard Worker                            "dtype: Optional[_dtype] = None",
700*da0073e9SAndroid Build Coastguard Worker                            "device: Optional[DeviceLikeType] = None",
701*da0073e9SAndroid Build Coastguard Worker                            "copy: Optional[_bool] = None",
702*da0073e9SAndroid Build Coastguard Worker                            "requires_grad: _bool = False",
703*da0073e9SAndroid Build Coastguard Worker                        ]
704*da0073e9SAndroid Build Coastguard Worker                    )
705*da0073e9SAndroid Build Coastguard Worker                )
706*da0073e9SAndroid Build Coastguard Worker            ],
707*da0073e9SAndroid Build Coastguard Worker            "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
708*da0073e9SAndroid Build Coastguard Worker            "frombuffer": [
709*da0073e9SAndroid Build Coastguard Worker                "def frombuffer({}) -> Tensor: ...".format(
710*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
711*da0073e9SAndroid Build Coastguard Worker                        [
712*da0073e9SAndroid Build Coastguard Worker                            "buffer: Any",
713*da0073e9SAndroid Build Coastguard Worker                            "*",
714*da0073e9SAndroid Build Coastguard Worker                            "dtype: _dtype",
715*da0073e9SAndroid Build Coastguard Worker                            "count: int = -1",
716*da0073e9SAndroid Build Coastguard Worker                            "offset: int = 0",
717*da0073e9SAndroid Build Coastguard Worker                            "requires_grad: _bool = False",
718*da0073e9SAndroid Build Coastguard Worker                        ]
719*da0073e9SAndroid Build Coastguard Worker                    )
720*da0073e9SAndroid Build Coastguard Worker                )
721*da0073e9SAndroid Build Coastguard Worker            ],
722*da0073e9SAndroid Build Coastguard Worker            "numel": ["def numel(self: Tensor) -> _int: ..."],
723*da0073e9SAndroid Build Coastguard Worker            "as_tensor": [
724*da0073e9SAndroid Build Coastguard Worker                "def as_tensor({}) -> Tensor: ...".format(
725*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
726*da0073e9SAndroid Build Coastguard Worker                        [
727*da0073e9SAndroid Build Coastguard Worker                            "data: Any",
728*da0073e9SAndroid Build Coastguard Worker                            "dtype: Optional[_dtype] = None",
729*da0073e9SAndroid Build Coastguard Worker                            DEVICE_PARAM,
730*da0073e9SAndroid Build Coastguard Worker                        ]
731*da0073e9SAndroid Build Coastguard Worker                    )
732*da0073e9SAndroid Build Coastguard Worker                )
733*da0073e9SAndroid Build Coastguard Worker            ],
734*da0073e9SAndroid Build Coastguard Worker            "get_num_threads": ["def get_num_threads() -> _int: ..."],
735*da0073e9SAndroid Build Coastguard Worker            "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
736*da0073e9SAndroid Build Coastguard Worker            "init_num_threads": ["def init_num_threads() -> None: ..."],
737*da0073e9SAndroid Build Coastguard Worker            "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."],
738*da0073e9SAndroid Build Coastguard Worker            "set_num_interop_threads": [
739*da0073e9SAndroid Build Coastguard Worker                "def set_num_interop_threads(num: _int) -> None: ..."
740*da0073e9SAndroid Build Coastguard Worker            ],
741*da0073e9SAndroid Build Coastguard Worker            # These functions are explicitly disabled by
742*da0073e9SAndroid Build Coastguard Worker            # SKIP_PYTHON_BINDINGS because they are hand bound.
743*da0073e9SAndroid Build Coastguard Worker            # Correspondingly, we must hand-write their signatures.
744*da0073e9SAndroid Build Coastguard Worker            "tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."],
745*da0073e9SAndroid Build Coastguard Worker            "sparse_coo_tensor": [
746*da0073e9SAndroid Build Coastguard Worker                "def sparse_coo_tensor({}) -> Tensor: ...".format(
747*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
748*da0073e9SAndroid Build Coastguard Worker                        [
749*da0073e9SAndroid Build Coastguard Worker                            "indices: Tensor",
750*da0073e9SAndroid Build Coastguard Worker                            "values: Union[Tensor, List]",
751*da0073e9SAndroid Build Coastguard Worker                            "size: Optional[_size] = None",
752*da0073e9SAndroid Build Coastguard Worker                            "*",
753*da0073e9SAndroid Build Coastguard Worker                            "dtype: Optional[_dtype] = None",
754*da0073e9SAndroid Build Coastguard Worker                            "device: Optional[DeviceLikeType] = None",
755*da0073e9SAndroid Build Coastguard Worker                            "requires_grad: _bool = False",
756*da0073e9SAndroid Build Coastguard Worker                            "check_invariants: Optional[_bool] = None",
757*da0073e9SAndroid Build Coastguard Worker                            "is_coalesced: Optional[_bool] = None",
758*da0073e9SAndroid Build Coastguard Worker                        ]
759*da0073e9SAndroid Build Coastguard Worker                    )
760*da0073e9SAndroid Build Coastguard Worker                )
761*da0073e9SAndroid Build Coastguard Worker            ],
762*da0073e9SAndroid Build Coastguard Worker            "sparse_compressed_tensor": [
763*da0073e9SAndroid Build Coastguard Worker                "def sparse_compressed_tensor({}) -> Tensor: ...".format(
764*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
765*da0073e9SAndroid Build Coastguard Worker                        [
766*da0073e9SAndroid Build Coastguard Worker                            "compressed_indices: Union[Tensor, List]",
767*da0073e9SAndroid Build Coastguard Worker                            "plain_indices: Union[Tensor, List]",
768*da0073e9SAndroid Build Coastguard Worker                            "values: Union[Tensor, List]",
769*da0073e9SAndroid Build Coastguard Worker                            "size: Optional[_size] = None",
770*da0073e9SAndroid Build Coastguard Worker                            "*",
771*da0073e9SAndroid Build Coastguard Worker                            "dtype: Optional[_dtype] = None",
772*da0073e9SAndroid Build Coastguard Worker                            "layout: Optional[_layout] = None",
773*da0073e9SAndroid Build Coastguard Worker                            "device: Optional[DeviceLikeType] = None",
774*da0073e9SAndroid Build Coastguard Worker                            "requires_grad: _bool = False",
775*da0073e9SAndroid Build Coastguard Worker                            "check_invariants: Optional[_bool] = None",
776*da0073e9SAndroid Build Coastguard Worker                        ]
777*da0073e9SAndroid Build Coastguard Worker                    )
778*da0073e9SAndroid Build Coastguard Worker                )
779*da0073e9SAndroid Build Coastguard Worker            ],
780*da0073e9SAndroid Build Coastguard Worker            "_sync": ["def _sync(t: Tensor) -> None: ..."],
781*da0073e9SAndroid Build Coastguard Worker            "_is_functional_tensor": [
782*da0073e9SAndroid Build Coastguard Worker                "def _is_functional_tensor(t: Tensor) -> _bool: ..."
783*da0073e9SAndroid Build Coastguard Worker            ],
784*da0073e9SAndroid Build Coastguard Worker            "_is_functional_tensor_base": [
785*da0073e9SAndroid Build Coastguard Worker                "def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
786*da0073e9SAndroid Build Coastguard Worker            ],
787*da0073e9SAndroid Build Coastguard Worker            "_from_functional_tensor": [
788*da0073e9SAndroid Build Coastguard Worker                "def _from_functional_tensor(t: Tensor) -> Tensor: ..."
789*da0073e9SAndroid Build Coastguard Worker            ],
790*da0073e9SAndroid Build Coastguard Worker            "_to_functional_tensor": [
791*da0073e9SAndroid Build Coastguard Worker                "def _to_functional_tensor(t: Tensor) -> Tensor: ..."
792*da0073e9SAndroid Build Coastguard Worker            ],
793*da0073e9SAndroid Build Coastguard Worker            "_functionalize_replace": [
794*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..."
795*da0073e9SAndroid Build Coastguard Worker            ],
796*da0073e9SAndroid Build Coastguard Worker            "_functionalize_commit_update": [
797*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_commit_update(t: Tensor) -> None: ..."
798*da0073e9SAndroid Build Coastguard Worker            ],
799*da0073e9SAndroid Build Coastguard Worker            "_functionalize_unsafe_set": [
800*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
801*da0073e9SAndroid Build Coastguard Worker            ],
802*da0073e9SAndroid Build Coastguard Worker            "_functionalize_mark_mutation_hidden_from_autograd": [
803*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..."
804*da0073e9SAndroid Build Coastguard Worker            ],
805*da0073e9SAndroid Build Coastguard Worker            "_functionalize_are_all_mutations_hidden_from_autograd": [
806*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ..."
807*da0073e9SAndroid Build Coastguard Worker            ],
808*da0073e9SAndroid Build Coastguard Worker            "_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [
809*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ..."
810*da0073e9SAndroid Build Coastguard Worker            ],
811*da0073e9SAndroid Build Coastguard Worker            "_functionalize_was_inductor_storage_resized": [
812*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ..."
813*da0073e9SAndroid Build Coastguard Worker            ],
814*da0073e9SAndroid Build Coastguard Worker            "_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."],
815*da0073e9SAndroid Build Coastguard Worker            "_functionalize_was_storage_changed": [
816*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..."
817*da0073e9SAndroid Build Coastguard Worker            ],
818*da0073e9SAndroid Build Coastguard Worker            "_functionalize_set_storage_changed": [
819*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..."
820*da0073e9SAndroid Build Coastguard Worker            ],
821*da0073e9SAndroid Build Coastguard Worker            "_functionalize_has_metadata_mutation": [
822*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..."
823*da0073e9SAndroid Build Coastguard Worker            ],
824*da0073e9SAndroid Build Coastguard Worker            "_functionalize_apply_view_metas": [
825*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_apply_view_metas(tensor: Tensor,  base: Tensor) -> Tensor: ..."
826*da0073e9SAndroid Build Coastguard Worker            ],
827*da0073e9SAndroid Build Coastguard Worker            "_functionalize_is_symbolic": [
828*da0073e9SAndroid Build Coastguard Worker                "def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..."
829*da0073e9SAndroid Build Coastguard Worker            ],
830*da0073e9SAndroid Build Coastguard Worker            "_enable_functionalization": [
831*da0073e9SAndroid Build Coastguard Worker                "def _enable_functionalization(*, reapply_views: _bool = False): ..."
832*da0073e9SAndroid Build Coastguard Worker            ],
833*da0073e9SAndroid Build Coastguard Worker            "_disable_functionalization": ["def _disable_functionalization(): ..."],
834*da0073e9SAndroid Build Coastguard Worker            "range": [
835*da0073e9SAndroid Build Coastguard Worker                "def range({}) -> Tensor: ...".format(
836*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
837*da0073e9SAndroid Build Coastguard Worker                        [
838*da0073e9SAndroid Build Coastguard Worker                            "start: Number",
839*da0073e9SAndroid Build Coastguard Worker                            "end: Number",
840*da0073e9SAndroid Build Coastguard Worker                            "step: Number = 1",
841*da0073e9SAndroid Build Coastguard Worker                            "*",
842*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
843*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
844*da0073e9SAndroid Build Coastguard Worker                        ]
845*da0073e9SAndroid Build Coastguard Worker                    )
846*da0073e9SAndroid Build Coastguard Worker                )
847*da0073e9SAndroid Build Coastguard Worker            ],
848*da0073e9SAndroid Build Coastguard Worker            "arange": [
849*da0073e9SAndroid Build Coastguard Worker                "def arange({}) -> Tensor: ...".format(
850*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
851*da0073e9SAndroid Build Coastguard Worker                        [
852*da0073e9SAndroid Build Coastguard Worker                            "start: Number",
853*da0073e9SAndroid Build Coastguard Worker                            "end: Number",
854*da0073e9SAndroid Build Coastguard Worker                            "step: Number",
855*da0073e9SAndroid Build Coastguard Worker                            "*",
856*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
857*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
858*da0073e9SAndroid Build Coastguard Worker                        ]
859*da0073e9SAndroid Build Coastguard Worker                    )
860*da0073e9SAndroid Build Coastguard Worker                ),
861*da0073e9SAndroid Build Coastguard Worker                "def arange({}) -> Tensor: ...".format(
862*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
863*da0073e9SAndroid Build Coastguard Worker                        [
864*da0073e9SAndroid Build Coastguard Worker                            "start: Number",
865*da0073e9SAndroid Build Coastguard Worker                            "end: Number",
866*da0073e9SAndroid Build Coastguard Worker                            "*",
867*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
868*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
869*da0073e9SAndroid Build Coastguard Worker                        ]
870*da0073e9SAndroid Build Coastguard Worker                    )
871*da0073e9SAndroid Build Coastguard Worker                ),
872*da0073e9SAndroid Build Coastguard Worker                "def arange({}) -> Tensor: ...".format(
873*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
874*da0073e9SAndroid Build Coastguard Worker                        [
875*da0073e9SAndroid Build Coastguard Worker                            "end: Number",
876*da0073e9SAndroid Build Coastguard Worker                            "*",
877*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
878*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
879*da0073e9SAndroid Build Coastguard Worker                        ]
880*da0073e9SAndroid Build Coastguard Worker                    )
881*da0073e9SAndroid Build Coastguard Worker                ),
882*da0073e9SAndroid Build Coastguard Worker            ],
883*da0073e9SAndroid Build Coastguard Worker            "linspace": [
884*da0073e9SAndroid Build Coastguard Worker                "def linspace({}) -> Tensor: ...".format(
885*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
886*da0073e9SAndroid Build Coastguard Worker                        [
887*da0073e9SAndroid Build Coastguard Worker                            "start: Number",
888*da0073e9SAndroid Build Coastguard Worker                            "end: Number",
889*da0073e9SAndroid Build Coastguard Worker                            "steps: Optional[_int] = None",
890*da0073e9SAndroid Build Coastguard Worker                            "*",
891*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
892*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
893*da0073e9SAndroid Build Coastguard Worker                        ]
894*da0073e9SAndroid Build Coastguard Worker                    )
895*da0073e9SAndroid Build Coastguard Worker                )
896*da0073e9SAndroid Build Coastguard Worker            ],
897*da0073e9SAndroid Build Coastguard Worker            "logspace": [
898*da0073e9SAndroid Build Coastguard Worker                "def logspace({}) -> Tensor: ...".format(
899*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
900*da0073e9SAndroid Build Coastguard Worker                        [
901*da0073e9SAndroid Build Coastguard Worker                            "start: Number",
902*da0073e9SAndroid Build Coastguard Worker                            "end: Number",
903*da0073e9SAndroid Build Coastguard Worker                            "steps: Optional[_int] = None",
904*da0073e9SAndroid Build Coastguard Worker                            "base: _float = 10.0",
905*da0073e9SAndroid Build Coastguard Worker                            "*",
906*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
907*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
908*da0073e9SAndroid Build Coastguard Worker                        ]
909*da0073e9SAndroid Build Coastguard Worker                    )
910*da0073e9SAndroid Build Coastguard Worker                )
911*da0073e9SAndroid Build Coastguard Worker            ],
912*da0073e9SAndroid Build Coastguard Worker            "randint": [
913*da0073e9SAndroid Build Coastguard Worker                "def randint({}) -> Tensor: ...".format(
914*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
915*da0073e9SAndroid Build Coastguard Worker                        [
916*da0073e9SAndroid Build Coastguard Worker                            "low: _int",
917*da0073e9SAndroid Build Coastguard Worker                            "high: _int",
918*da0073e9SAndroid Build Coastguard Worker                            "size: _size",
919*da0073e9SAndroid Build Coastguard Worker                            "*",
920*da0073e9SAndroid Build Coastguard Worker                            "generator: Optional[Generator] = None",
921*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
922*da0073e9SAndroid Build Coastguard Worker                        ]
923*da0073e9SAndroid Build Coastguard Worker                    )
924*da0073e9SAndroid Build Coastguard Worker                ),
925*da0073e9SAndroid Build Coastguard Worker                "def randint({}) -> Tensor: ...".format(
926*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
927*da0073e9SAndroid Build Coastguard Worker                        [
928*da0073e9SAndroid Build Coastguard Worker                            "high: _int",
929*da0073e9SAndroid Build Coastguard Worker                            "size: _size",
930*da0073e9SAndroid Build Coastguard Worker                            "*",
931*da0073e9SAndroid Build Coastguard Worker                            "generator: Optional[Generator] = None",
932*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
933*da0073e9SAndroid Build Coastguard Worker                        ]
934*da0073e9SAndroid Build Coastguard Worker                    )
935*da0073e9SAndroid Build Coastguard Worker                ),
936*da0073e9SAndroid Build Coastguard Worker            ],
937*da0073e9SAndroid Build Coastguard Worker            "full": [
938*da0073e9SAndroid Build Coastguard Worker                "def full({}) -> Tensor: ...".format(
939*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
940*da0073e9SAndroid Build Coastguard Worker                        [
941*da0073e9SAndroid Build Coastguard Worker                            "size: _size",
942*da0073e9SAndroid Build Coastguard Worker                            "fill_value: Union[Number, _complex]",
943*da0073e9SAndroid Build Coastguard Worker                            "*",
944*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
945*da0073e9SAndroid Build Coastguard Worker                            "layout: _layout = strided",
946*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
947*da0073e9SAndroid Build Coastguard Worker                        ]
948*da0073e9SAndroid Build Coastguard Worker                    )
949*da0073e9SAndroid Build Coastguard Worker                ),
950*da0073e9SAndroid Build Coastguard Worker                "def full({}) -> Tensor: ...".format(
951*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
952*da0073e9SAndroid Build Coastguard Worker                        [
953*da0073e9SAndroid Build Coastguard Worker                            "size: _size",
954*da0073e9SAndroid Build Coastguard Worker                            "fill_value: Union[Number, _complex]",
955*da0073e9SAndroid Build Coastguard Worker                            "*",
956*da0073e9SAndroid Build Coastguard Worker                            "names: List[Union[str, None]]",
957*da0073e9SAndroid Build Coastguard Worker                            "layout: _layout = strided",
958*da0073e9SAndroid Build Coastguard Worker                            FACTORY_PARAMS,
959*da0073e9SAndroid Build Coastguard Worker                        ]
960*da0073e9SAndroid Build Coastguard Worker                    )
961*da0073e9SAndroid Build Coastguard Worker                ),
962*da0073e9SAndroid Build Coastguard Worker            ],
963*da0073e9SAndroid Build Coastguard Worker            "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
964*da0073e9SAndroid Build Coastguard Worker            "is_inference_mode_enabled": [
965*da0073e9SAndroid Build Coastguard Worker                "def is_inference_mode_enabled() -> _bool: ..."
966*da0073e9SAndroid Build Coastguard Worker            ],
967*da0073e9SAndroid Build Coastguard Worker            "nonzero": [
968*da0073e9SAndroid Build Coastguard Worker                "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...",
969*da0073e9SAndroid Build Coastguard Worker                "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
970*da0073e9SAndroid Build Coastguard Worker            ],
971*da0073e9SAndroid Build Coastguard Worker            "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
972*da0073e9SAndroid Build Coastguard Worker            "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
973*da0073e9SAndroid Build Coastguard Worker            "saddmm": [
974*da0073e9SAndroid Build Coastguard Worker                "def saddmm({}) -> Tensor: ...".format(
975*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
976*da0073e9SAndroid Build Coastguard Worker                        [
977*da0073e9SAndroid Build Coastguard Worker                            "input: Tensor",
978*da0073e9SAndroid Build Coastguard Worker                            "mat1: Tensor",
979*da0073e9SAndroid Build Coastguard Worker                            "mat2: Tensor",
980*da0073e9SAndroid Build Coastguard Worker                            "*",
981*da0073e9SAndroid Build Coastguard Worker                            "beta: Number = 1",
982*da0073e9SAndroid Build Coastguard Worker                            "alpha: Number = 1",
983*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
984*da0073e9SAndroid Build Coastguard Worker                        ]
985*da0073e9SAndroid Build Coastguard Worker                    )
986*da0073e9SAndroid Build Coastguard Worker                )
987*da0073e9SAndroid Build Coastguard Worker            ],
988*da0073e9SAndroid Build Coastguard Worker            "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
989*da0073e9SAndroid Build Coastguard Worker            "div": [
990*da0073e9SAndroid Build Coastguard Worker                "def div({}) -> Tensor: ...".format(
991*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
992*da0073e9SAndroid Build Coastguard Worker                        [
993*da0073e9SAndroid Build Coastguard Worker                            "input: Union[Tensor, Number]",
994*da0073e9SAndroid Build Coastguard Worker                            "other: Union[Tensor, Number]",
995*da0073e9SAndroid Build Coastguard Worker                            "*",
996*da0073e9SAndroid Build Coastguard Worker                            "rounding_mode: Optional[str] = None",
997*da0073e9SAndroid Build Coastguard Worker                            "out: Optional[Tensor] = None",
998*da0073e9SAndroid Build Coastguard Worker                        ]
999*da0073e9SAndroid Build Coastguard Worker                    )
1000*da0073e9SAndroid Build Coastguard Worker                )
1001*da0073e9SAndroid Build Coastguard Worker            ],
1002*da0073e9SAndroid Build Coastguard Worker        }
1003*da0073e9SAndroid Build Coastguard Worker    )
1004*da0073e9SAndroid Build Coastguard Worker    for binop in ["true_divide", "floor_divide"]:
1005*da0073e9SAndroid Build Coastguard Worker        unsorted_function_hints[binop].append(
1006*da0073e9SAndroid Build Coastguard Worker            f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], "
1007*da0073e9SAndroid Build Coastguard Worker            "*, out: Optional[Tensor] = None) -> Tensor: ..."
1008*da0073e9SAndroid Build Coastguard Worker        )
1009*da0073e9SAndroid Build Coastguard Worker    for binop in ["mul"]:
1010*da0073e9SAndroid Build Coastguard Worker        unsorted_function_hints[binop].append(
1011*da0073e9SAndroid Build Coastguard Worker            f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], "
1012*da0073e9SAndroid Build Coastguard Worker            "*, out: Optional[Tensor] = None) -> Tensor: ..."
1013*da0073e9SAndroid Build Coastguard Worker        )
1014*da0073e9SAndroid Build Coastguard Worker    for binop in ["add", "sub"]:
1015*da0073e9SAndroid Build Coastguard Worker        unsorted_function_hints[binop].append(
1016*da0073e9SAndroid Build Coastguard Worker            f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], "
1017*da0073e9SAndroid Build Coastguard Worker            "*, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: ..."
1018*da0073e9SAndroid Build Coastguard Worker        )
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker    native_functions = parse_native_yaml(
1021*da0073e9SAndroid Build Coastguard Worker        native_yaml_path, tags_yaml_path
1022*da0073e9SAndroid Build Coastguard Worker    ).native_functions
1023*da0073e9SAndroid Build Coastguard Worker    native_functions = list(filter(should_generate_py_binding, native_functions))
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker    function_signatures = load_signatures(
1026*da0073e9SAndroid Build Coastguard Worker        native_functions, deprecated_yaml_path, method=False, pyi=True
1027*da0073e9SAndroid Build Coastguard Worker    )
1028*da0073e9SAndroid Build Coastguard Worker    sig_groups = get_py_torch_functions(function_signatures)
1029*da0073e9SAndroid Build Coastguard Worker    for group in sorted(sig_groups, key=lambda g: g.signature.name):
1030*da0073e9SAndroid Build Coastguard Worker        name = group.signature.name
1031*da0073e9SAndroid Build Coastguard Worker        unsorted_function_hints[name] += generate_type_hints(group)
1032*da0073e9SAndroid Build Coastguard Worker
1033*da0073e9SAndroid Build Coastguard Worker        structseq = returns_structseq_pyi(group.signature)
1034*da0073e9SAndroid Build Coastguard Worker        if structseq is not None and not group.signature.deprecated:
1035*da0073e9SAndroid Build Coastguard Worker            # deprecated structseqs are currently not included for torch functions
1036*da0073e9SAndroid Build Coastguard Worker            tuple_name, tuple_def = structseq
1037*da0073e9SAndroid Build Coastguard Worker            if tuple_name in structseqs:
1038*da0073e9SAndroid Build Coastguard Worker                assert structseqs[tuple_name] == tuple_def
1039*da0073e9SAndroid Build Coastguard Worker            else:
1040*da0073e9SAndroid Build Coastguard Worker                structseqs[tuple_name] = tuple_def
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker    def replace_special_case(hint: str) -> str:
1043*da0073e9SAndroid Build Coastguard Worker        # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h
1044*da0073e9SAndroid Build Coastguard Worker        hint = hint.replace("at::Reduction::Mean", "1")
1045*da0073e9SAndroid Build Coastguard Worker        hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None")
1046*da0073e9SAndroid Build Coastguard Worker        # Match both:
1047*da0073e9SAndroid Build Coastguard Worker        # ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None"
1048*da0073e9SAndroid Build Coastguard Worker        # ": Union[Tuple[Tensor, ...], List[Tensor]] = None"
1049*da0073e9SAndroid Build Coastguard Worker        hint = hint.replace(
1050*da0073e9SAndroid Build Coastguard Worker            "Tuple[Tensor, ...], List[Tensor]] = None",
1051*da0073e9SAndroid Build Coastguard Worker            "Tuple[Tensor, ...], List[Tensor], None] = None",
1052*da0073e9SAndroid Build Coastguard Worker        )
1053*da0073e9SAndroid Build Coastguard Worker        return hint
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker    docstrs = gather_docstrs()
1056*da0073e9SAndroid Build Coastguard Worker    function_hints = []
1057*da0073e9SAndroid Build Coastguard Worker    for name, hints in sorted(unsorted_function_hints.items()):
1058*da0073e9SAndroid Build Coastguard Worker        hints = [replace_special_case(h) for h in hints]
1059*da0073e9SAndroid Build Coastguard Worker        if len(hints) > 1:
1060*da0073e9SAndroid Build Coastguard Worker            hints = ["@overload\n" + h for h in hints]
1061*da0073e9SAndroid Build Coastguard Worker        docstr = docstrs.get(f"torch.{name}")
1062*da0073e9SAndroid Build Coastguard Worker        if docstr is not None:
1063*da0073e9SAndroid Build Coastguard Worker            hints = [add_docstr_to_hint(docstr, h) for h in hints]
1064*da0073e9SAndroid Build Coastguard Worker        function_hints += hints
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker    # Generate type signatures for Tensor methods
1067*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1068*da0073e9SAndroid Build Coastguard Worker
1069*da0073e9SAndroid Build Coastguard Worker    unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
1070*da0073e9SAndroid Build Coastguard Worker    unsorted_tensor_method_hints.update(
1071*da0073e9SAndroid Build Coastguard Worker        {
1072*da0073e9SAndroid Build Coastguard Worker            "size": [
1073*da0073e9SAndroid Build Coastguard Worker                "def size(self, dim: None = None) -> Size: ...",
1074*da0073e9SAndroid Build Coastguard Worker                "def size(self, dim: _int) -> _int: ...",
1075*da0073e9SAndroid Build Coastguard Worker            ],
1076*da0073e9SAndroid Build Coastguard Worker            "stride": [
1077*da0073e9SAndroid Build Coastguard Worker                "def stride(self, dim: None = None) -> Tuple[_int, ...]: ...",
1078*da0073e9SAndroid Build Coastguard Worker                "def stride(self, dim: _int) -> _int: ...",
1079*da0073e9SAndroid Build Coastguard Worker            ],
1080*da0073e9SAndroid Build Coastguard Worker            "new_ones": [
1081*da0073e9SAndroid Build Coastguard Worker                f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..."
1082*da0073e9SAndroid Build Coastguard Worker            ],
1083*da0073e9SAndroid Build Coastguard Worker            "new_tensor": [
1084*da0073e9SAndroid Build Coastguard Worker                f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..."
1085*da0073e9SAndroid Build Coastguard Worker            ],
1086*da0073e9SAndroid Build Coastguard Worker            "__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."],
1087*da0073e9SAndroid Build Coastguard Worker            # new and __init__ have the same signatures differ only in return type
1088*da0073e9SAndroid Build Coastguard Worker            # Adapted from legacy_tensor_ctor and legacy_tensor_new
1089*da0073e9SAndroid Build Coastguard Worker            "new": [
1090*da0073e9SAndroid Build Coastguard Worker                f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...",
1091*da0073e9SAndroid Build Coastguard Worker                "def new(cls, storage: Storage) -> Self: ...",
1092*da0073e9SAndroid Build Coastguard Worker                "def new(cls, other: Tensor) -> Self: ...",
1093*da0073e9SAndroid Build Coastguard Worker                f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...",
1094*da0073e9SAndroid Build Coastguard Worker            ],
1095*da0073e9SAndroid Build Coastguard Worker            "__init__": [
1096*da0073e9SAndroid Build Coastguard Worker                f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...",
1097*da0073e9SAndroid Build Coastguard Worker                "def __init__(self, storage: Storage) -> None: ...",
1098*da0073e9SAndroid Build Coastguard Worker                "def __init__(self, other: Tensor) -> None: ...",
1099*da0073e9SAndroid Build Coastguard Worker                f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...",
1100*da0073e9SAndroid Build Coastguard Worker            ],
1101*da0073e9SAndroid Build Coastguard Worker            "as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."],
1102*da0073e9SAndroid Build Coastguard Worker            "_make_subclass": [
1103*da0073e9SAndroid Build Coastguard Worker                "@staticmethod    \ndef _make_subclass({}) -> S: ...".format(
1104*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
1105*da0073e9SAndroid Build Coastguard Worker                        [
1106*da0073e9SAndroid Build Coastguard Worker                            "cls: _Type[S]",
1107*da0073e9SAndroid Build Coastguard Worker                            "data: Tensor",
1108*da0073e9SAndroid Build Coastguard Worker                            "require_grad: _bool = False",
1109*da0073e9SAndroid Build Coastguard Worker                            "dispatch_strides: _bool = False",
1110*da0073e9SAndroid Build Coastguard Worker                            "dispatch_device: _bool = False",
1111*da0073e9SAndroid Build Coastguard Worker                            "device_for_backend_keys: Optional[_device] = None",
1112*da0073e9SAndroid Build Coastguard Worker                        ]
1113*da0073e9SAndroid Build Coastguard Worker                    )
1114*da0073e9SAndroid Build Coastguard Worker                )
1115*da0073e9SAndroid Build Coastguard Worker            ],
1116*da0073e9SAndroid Build Coastguard Worker            "__contains__": ["def __contains__(self, other: Any, /) -> _bool: ..."],
1117*da0073e9SAndroid Build Coastguard Worker            "__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."],
1118*da0073e9SAndroid Build Coastguard Worker            "__setitem__": [
1119*da0073e9SAndroid Build Coastguard Worker                f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..."
1120*da0073e9SAndroid Build Coastguard Worker            ],
1121*da0073e9SAndroid Build Coastguard Worker            "tolist": ["def tolist(self) -> List: ..."],
1122*da0073e9SAndroid Build Coastguard Worker            "requires_grad_": [
1123*da0073e9SAndroid Build Coastguard Worker                "def requires_grad_(self, mode: _bool = True) -> Tensor: ..."
1124*da0073e9SAndroid Build Coastguard Worker            ],
1125*da0073e9SAndroid Build Coastguard Worker            "element_size": ["def element_size(self) -> _int: ..."],
1126*da0073e9SAndroid Build Coastguard Worker            "data_ptr": ["def data_ptr(self) -> _int: ..."],
1127*da0073e9SAndroid Build Coastguard Worker            "dim": ["def dim(self) -> _int: ..."],
1128*da0073e9SAndroid Build Coastguard Worker            "nonzero": [
1129*da0073e9SAndroid Build Coastguard Worker                "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...",
1130*da0073e9SAndroid Build Coastguard Worker                "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
1131*da0073e9SAndroid Build Coastguard Worker            ],
1132*da0073e9SAndroid Build Coastguard Worker            "numel": ["def numel(self) -> _int: ..."],
1133*da0073e9SAndroid Build Coastguard Worker            "ndimension": ["def ndimension(self) -> _int: ..."],
1134*da0073e9SAndroid Build Coastguard Worker            "nelement": ["def nelement(self) -> _int: ..."],
1135*da0073e9SAndroid Build Coastguard Worker            "cuda": [
1136*da0073e9SAndroid Build Coastguard Worker                "def cuda({}) -> Tensor: ...".format(
1137*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
1138*da0073e9SAndroid Build Coastguard Worker                        [
1139*da0073e9SAndroid Build Coastguard Worker                            "self",
1140*da0073e9SAndroid Build Coastguard Worker                            "device: Optional[Union[_device, _int, str]] = None",
1141*da0073e9SAndroid Build Coastguard Worker                            "non_blocking: _bool = False",
1142*da0073e9SAndroid Build Coastguard Worker                            "memory_format: torch.memory_format = torch.preserve_format",
1143*da0073e9SAndroid Build Coastguard Worker                        ]
1144*da0073e9SAndroid Build Coastguard Worker                    )
1145*da0073e9SAndroid Build Coastguard Worker                )
1146*da0073e9SAndroid Build Coastguard Worker            ],
1147*da0073e9SAndroid Build Coastguard Worker            "xpu": [
1148*da0073e9SAndroid Build Coastguard Worker                "def xpu({}) -> Tensor: ...".format(
1149*da0073e9SAndroid Build Coastguard Worker                    ", ".join(
1150*da0073e9SAndroid Build Coastguard Worker                        [
1151*da0073e9SAndroid Build Coastguard Worker                            "self",
1152*da0073e9SAndroid Build Coastguard Worker                            "device: Optional[Union[_device, _int, str]] = None",
1153*da0073e9SAndroid Build Coastguard Worker                            "non_blocking: _bool = False",
1154*da0073e9SAndroid Build Coastguard Worker                            "memory_format: torch.memory_format = torch.preserve_format",
1155*da0073e9SAndroid Build Coastguard Worker                        ]
1156*da0073e9SAndroid Build Coastguard Worker                    )
1157*da0073e9SAndroid Build Coastguard Worker                )
1158*da0073e9SAndroid Build Coastguard Worker            ],
1159*da0073e9SAndroid Build Coastguard Worker            "cpu": [
1160*da0073e9SAndroid Build Coastguard Worker                "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..."
1161*da0073e9SAndroid Build Coastguard Worker            ],
1162*da0073e9SAndroid Build Coastguard Worker            "numpy": ["def numpy(self, *, force: _bool = False) -> numpy.ndarray: ..."],
1163*da0073e9SAndroid Build Coastguard Worker            "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
1164*da0073e9SAndroid Build Coastguard Worker            "map_": [
1165*da0073e9SAndroid Build Coastguard Worker                "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."
1166*da0073e9SAndroid Build Coastguard Worker            ],
1167*da0073e9SAndroid Build Coastguard Worker            "map2_": [
1168*da0073e9SAndroid Build Coastguard Worker                "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
1169*da0073e9SAndroid Build Coastguard Worker            ],
1170*da0073e9SAndroid Build Coastguard Worker            "storage": ["def untyped_storage(self) -> UntypedStorage: ..."],
1171*da0073e9SAndroid Build Coastguard Worker            "storage_type": ["def storage_type(self) -> Storage: ..."],
1172*da0073e9SAndroid Build Coastguard Worker            "type": [
1173*da0073e9SAndroid Build Coastguard Worker                "def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...",
1174*da0073e9SAndroid Build Coastguard Worker                "def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...",
1175*da0073e9SAndroid Build Coastguard Worker            ],
1176*da0073e9SAndroid Build Coastguard Worker            "get_device": ["def get_device(self) -> _int: ..."],
1177*da0073e9SAndroid Build Coastguard Worker            "contiguous": [
1178*da0073e9SAndroid Build Coastguard Worker                "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
1179*da0073e9SAndroid Build Coastguard Worker            ],
1180*da0073e9SAndroid Build Coastguard Worker            "has_names": ["def has_names(self) -> _bool: ..."],
1181*da0073e9SAndroid Build Coastguard Worker            "is_contiguous": [
1182*da0073e9SAndroid Build Coastguard Worker                "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
1183*da0073e9SAndroid Build Coastguard Worker            ],
1184*da0073e9SAndroid Build Coastguard Worker            "_is_view": ["def _is_view(self) -> _bool: ..."],
1185*da0073e9SAndroid Build Coastguard Worker            "is_cpu": ["is_cpu: _bool"],
1186*da0073e9SAndroid Build Coastguard Worker            "is_cuda": ["is_cuda: _bool"],
1187*da0073e9SAndroid Build Coastguard Worker            "is_leaf": ["is_leaf: _bool"],
1188*da0073e9SAndroid Build Coastguard Worker            "is_nested": ["is_nested: _bool"],
1189*da0073e9SAndroid Build Coastguard Worker            "is_sparse": ["is_sparse: _bool"],
1190*da0073e9SAndroid Build Coastguard Worker            "is_sparse_csr": ["is_sparse_csr: _bool"],
1191*da0073e9SAndroid Build Coastguard Worker            "is_quantized": ["is_quantized: _bool"],
1192*da0073e9SAndroid Build Coastguard Worker            "is_meta": ["is_meta: _bool"],
1193*da0073e9SAndroid Build Coastguard Worker            "is_mps": ["is_mps: _bool"],
1194*da0073e9SAndroid Build Coastguard Worker            "is_mtia": ["is_mtia: _bool"],
1195*da0073e9SAndroid Build Coastguard Worker            "is_maia": ["is_maia: _bool"],
1196*da0073e9SAndroid Build Coastguard Worker            "is_mkldnn": ["is_mkldnn: _bool"],
1197*da0073e9SAndroid Build Coastguard Worker            "is_vulkan": ["is_vulkan: _bool"],
1198*da0073e9SAndroid Build Coastguard Worker            "is_ipu": ["is_ipu: _bool"],
1199*da0073e9SAndroid Build Coastguard Worker            "storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."],
1200*da0073e9SAndroid Build Coastguard Worker            "to": [
1201*da0073e9SAndroid Build Coastguard Worker                (
1202*da0073e9SAndroid Build Coastguard Worker                    f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, "
1203*da0073e9SAndroid Build Coastguard Worker                    "memory_format: Optional[torch.memory_format] = None) -> Tensor: ..."
1204*da0073e9SAndroid Build Coastguard Worker                )
1205*da0073e9SAndroid Build Coastguard Worker                for args in [
1206*da0073e9SAndroid Build Coastguard Worker                    "dtype: _dtype",
1207*da0073e9SAndroid Build Coastguard Worker                    "device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None",
1208*da0073e9SAndroid Build Coastguard Worker                    "other: Tensor",
1209*da0073e9SAndroid Build Coastguard Worker                ]
1210*da0073e9SAndroid Build Coastguard Worker            ],
1211*da0073e9SAndroid Build Coastguard Worker            "item": ["def item(self) -> Number: ..."],
1212*da0073e9SAndroid Build Coastguard Worker            "copy_": [
1213*da0073e9SAndroid Build Coastguard Worker                "def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..."
1214*da0073e9SAndroid Build Coastguard Worker            ],
1215*da0073e9SAndroid Build Coastguard Worker            "set_": [
1216*da0073e9SAndroid Build Coastguard Worker                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
1217*da0073e9SAndroid Build Coastguard Worker                "offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...",
1218*da0073e9SAndroid Build Coastguard Worker                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
1219*da0073e9SAndroid Build Coastguard Worker            ],
1220*da0073e9SAndroid Build Coastguard Worker            "split": [
1221*da0073e9SAndroid Build Coastguard Worker                "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...",
1222*da0073e9SAndroid Build Coastguard Worker                "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
1223*da0073e9SAndroid Build Coastguard Worker            ],
1224*da0073e9SAndroid Build Coastguard Worker            "div": [
1225*da0073e9SAndroid Build Coastguard Worker                "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
1226*da0073e9SAndroid Build Coastguard Worker            ],
1227*da0073e9SAndroid Build Coastguard Worker            "div_": [
1228*da0073e9SAndroid Build Coastguard Worker                "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
1229*da0073e9SAndroid Build Coastguard Worker            ],
1230*da0073e9SAndroid Build Coastguard Worker        }
1231*da0073e9SAndroid Build Coastguard Worker    )
1232*da0073e9SAndroid Build Coastguard Worker    for binop in ["true_divide", "floor_divide"]:
1233*da0073e9SAndroid Build Coastguard Worker        for inplace in [False, True]:
1234*da0073e9SAndroid Build Coastguard Worker            out_suffix = ", *, out: Optional[Tensor] = None"
1235*da0073e9SAndroid Build Coastguard Worker            if inplace:
1236*da0073e9SAndroid Build Coastguard Worker                binop += "_"
1237*da0073e9SAndroid Build Coastguard Worker                out_suffix = ""
1238*da0073e9SAndroid Build Coastguard Worker            unsorted_tensor_method_hints[binop].append(
1239*da0073e9SAndroid Build Coastguard Worker                f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})"
1240*da0073e9SAndroid Build Coastguard Worker                " -> Tensor: ..."
1241*da0073e9SAndroid Build Coastguard Worker            )
1242*da0073e9SAndroid Build Coastguard Worker    for binop in ["mul"]:
1243*da0073e9SAndroid Build Coastguard Worker        for inplace in [False, True]:
1244*da0073e9SAndroid Build Coastguard Worker            out_suffix = ", *, out: Optional[Tensor] = None"
1245*da0073e9SAndroid Build Coastguard Worker            if inplace:
1246*da0073e9SAndroid Build Coastguard Worker                binop += "_"
1247*da0073e9SAndroid Build Coastguard Worker                out_suffix = ""
1248*da0073e9SAndroid Build Coastguard Worker            unsorted_tensor_method_hints[binop].append(
1249*da0073e9SAndroid Build Coastguard Worker                f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat]{out_suffix})"
1250*da0073e9SAndroid Build Coastguard Worker                " -> Tensor: ..."
1251*da0073e9SAndroid Build Coastguard Worker            )
1252*da0073e9SAndroid Build Coastguard Worker    for binop in ["add", "sub"]:
1253*da0073e9SAndroid Build Coastguard Worker        for inplace in [False, True]:
1254*da0073e9SAndroid Build Coastguard Worker            out_suffix = ", out: Optional[Tensor] = None"
1255*da0073e9SAndroid Build Coastguard Worker            if inplace:
1256*da0073e9SAndroid Build Coastguard Worker                binop += "_"
1257*da0073e9SAndroid Build Coastguard Worker                out_suffix = ""
1258*da0073e9SAndroid Build Coastguard Worker            unsorted_tensor_method_hints[binop].append(
1259*da0073e9SAndroid Build Coastguard Worker                f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], "
1260*da0073e9SAndroid Build Coastguard Worker                f"*, alpha: Optional[Union[Number, _complex]] = 1{out_suffix})"
1261*da0073e9SAndroid Build Coastguard Worker                " -> Tensor: ..."
1262*da0073e9SAndroid Build Coastguard Worker            )
1263*da0073e9SAndroid Build Coastguard Worker    simple_conversions = [
1264*da0073e9SAndroid Build Coastguard Worker        "byte",
1265*da0073e9SAndroid Build Coastguard Worker        "char",
1266*da0073e9SAndroid Build Coastguard Worker        "double",
1267*da0073e9SAndroid Build Coastguard Worker        "float",
1268*da0073e9SAndroid Build Coastguard Worker        "half",
1269*da0073e9SAndroid Build Coastguard Worker        "int",
1270*da0073e9SAndroid Build Coastguard Worker        "long",
1271*da0073e9SAndroid Build Coastguard Worker        "short",
1272*da0073e9SAndroid Build Coastguard Worker        "bool",
1273*da0073e9SAndroid Build Coastguard Worker        "bfloat16",
1274*da0073e9SAndroid Build Coastguard Worker    ]
1275*da0073e9SAndroid Build Coastguard Worker    for name in simple_conversions:
1276*da0073e9SAndroid Build Coastguard Worker        unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker    # pyi tensor methods don't currently include deprecated signatures for some reason
1279*da0073e9SAndroid Build Coastguard Worker    # TODO: we should probably add them in
1280*da0073e9SAndroid Build Coastguard Worker    tensor_method_signatures = load_signatures(
1281*da0073e9SAndroid Build Coastguard Worker        native_functions,
1282*da0073e9SAndroid Build Coastguard Worker        deprecated_yaml_path,
1283*da0073e9SAndroid Build Coastguard Worker        method=True,
1284*da0073e9SAndroid Build Coastguard Worker        skip_deprecated=True,
1285*da0073e9SAndroid Build Coastguard Worker        pyi=True,
1286*da0073e9SAndroid Build Coastguard Worker    )
1287*da0073e9SAndroid Build Coastguard Worker    tensor_method_sig_groups = get_py_torch_functions(
1288*da0073e9SAndroid Build Coastguard Worker        tensor_method_signatures, method=True
1289*da0073e9SAndroid Build Coastguard Worker    )
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
1292*da0073e9SAndroid Build Coastguard Worker        name = group.signature.name
1293*da0073e9SAndroid Build Coastguard Worker        unsorted_tensor_method_hints[name] += generate_type_hints(group)
1294*da0073e9SAndroid Build Coastguard Worker
1295*da0073e9SAndroid Build Coastguard Worker        structseq = returns_structseq_pyi(group.signature)
1296*da0073e9SAndroid Build Coastguard Worker        if structseq is not None and not group.signature.deprecated:
1297*da0073e9SAndroid Build Coastguard Worker            # deprecated structseqs are currently not included for torch functions
1298*da0073e9SAndroid Build Coastguard Worker            tuple_name, tuple_def = structseq
1299*da0073e9SAndroid Build Coastguard Worker            if tuple_name in structseqs:
1300*da0073e9SAndroid Build Coastguard Worker                assert structseqs[tuple_name] == tuple_def
1301*da0073e9SAndroid Build Coastguard Worker            else:
1302*da0073e9SAndroid Build Coastguard Worker                structseqs[tuple_name] = tuple_def
1303*da0073e9SAndroid Build Coastguard Worker
1304*da0073e9SAndroid Build Coastguard Worker    for op in all_ops:
1305*da0073e9SAndroid Build Coastguard Worker        name = f"__{op}__"
1306*da0073e9SAndroid Build Coastguard Worker        unsorted_tensor_method_hints[name] += sig_for_ops(name)
1307*da0073e9SAndroid Build Coastguard Worker
1308*da0073e9SAndroid Build Coastguard Worker    tensor_method_hints = []
1309*da0073e9SAndroid Build Coastguard Worker    for name, hints in sorted(unsorted_tensor_method_hints.items()):
1310*da0073e9SAndroid Build Coastguard Worker        if len(hints) > 1:
1311*da0073e9SAndroid Build Coastguard Worker            hints = ["@overload\n" + h for h in hints]
1312*da0073e9SAndroid Build Coastguard Worker        docstr = docstrs.get(f"torch._C.TensorBase.{name}")
1313*da0073e9SAndroid Build Coastguard Worker        if docstr is not None:
1314*da0073e9SAndroid Build Coastguard Worker            hints = [add_docstr_to_hint(docstr, h) for h in hints]
1315*da0073e9SAndroid Build Coastguard Worker        tensor_method_hints += hints
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker    # TODO: Missing type hints for nn
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker    # Generate structseq definitions
1320*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker    structseq_defs = [f"{defn}\n" for defn in structseqs.values()]
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker    # Generate type signatures for legacy classes
1325*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker    legacy_storage_base_hints = ["class StorageBase(object): ..."]
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker    legacy_class_hints = []
1330*da0073e9SAndroid Build Coastguard Worker    for c in (
1331*da0073e9SAndroid Build Coastguard Worker        "DoubleTensor",
1332*da0073e9SAndroid Build Coastguard Worker        "FloatTensor",
1333*da0073e9SAndroid Build Coastguard Worker        "BFloat16Tensor",
1334*da0073e9SAndroid Build Coastguard Worker        "LongTensor",
1335*da0073e9SAndroid Build Coastguard Worker        "IntTensor",
1336*da0073e9SAndroid Build Coastguard Worker        "ShortTensor",
1337*da0073e9SAndroid Build Coastguard Worker        "HalfTensor",
1338*da0073e9SAndroid Build Coastguard Worker        "CharTensor",
1339*da0073e9SAndroid Build Coastguard Worker        "ByteTensor",
1340*da0073e9SAndroid Build Coastguard Worker        "BoolTensor",
1341*da0073e9SAndroid Build Coastguard Worker    ):
1342*da0073e9SAndroid Build Coastguard Worker        legacy_class_hints.append(f"class {c}(Tensor): ...")
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker    # Generate type signatures for dtype classes
1345*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1346*da0073e9SAndroid Build Coastguard Worker
1347*da0073e9SAndroid Build Coastguard Worker    # TODO: don't explicitly list dtypes here; get it from canonical
1348*da0073e9SAndroid Build Coastguard Worker    # source
1349*da0073e9SAndroid Build Coastguard Worker    dtype_class_hints = [
1350*da0073e9SAndroid Build Coastguard Worker        f"{n}: dtype = ..."
1351*da0073e9SAndroid Build Coastguard Worker        for n in [
1352*da0073e9SAndroid Build Coastguard Worker            "float32",
1353*da0073e9SAndroid Build Coastguard Worker            "float",
1354*da0073e9SAndroid Build Coastguard Worker            "float64",
1355*da0073e9SAndroid Build Coastguard Worker            "double",
1356*da0073e9SAndroid Build Coastguard Worker            "float16",
1357*da0073e9SAndroid Build Coastguard Worker            "bfloat16",
1358*da0073e9SAndroid Build Coastguard Worker            "float8_e4m3fn",
1359*da0073e9SAndroid Build Coastguard Worker            "float8_e4m3fnuz",
1360*da0073e9SAndroid Build Coastguard Worker            "float8_e5m2",
1361*da0073e9SAndroid Build Coastguard Worker            "float8_e5m2fnuz",
1362*da0073e9SAndroid Build Coastguard Worker            "half",
1363*da0073e9SAndroid Build Coastguard Worker            "uint8",
1364*da0073e9SAndroid Build Coastguard Worker            "uint16",
1365*da0073e9SAndroid Build Coastguard Worker            "uint32",
1366*da0073e9SAndroid Build Coastguard Worker            "uint64",
1367*da0073e9SAndroid Build Coastguard Worker            "int8",
1368*da0073e9SAndroid Build Coastguard Worker            "int16",
1369*da0073e9SAndroid Build Coastguard Worker            "short",
1370*da0073e9SAndroid Build Coastguard Worker            "int32",
1371*da0073e9SAndroid Build Coastguard Worker            "int",
1372*da0073e9SAndroid Build Coastguard Worker            "int64",
1373*da0073e9SAndroid Build Coastguard Worker            "long",
1374*da0073e9SAndroid Build Coastguard Worker            "complex32",
1375*da0073e9SAndroid Build Coastguard Worker            "complex64",
1376*da0073e9SAndroid Build Coastguard Worker            "chalf",
1377*da0073e9SAndroid Build Coastguard Worker            "cfloat",
1378*da0073e9SAndroid Build Coastguard Worker            "complex128",
1379*da0073e9SAndroid Build Coastguard Worker            "cdouble",
1380*da0073e9SAndroid Build Coastguard Worker            "quint8",
1381*da0073e9SAndroid Build Coastguard Worker            "qint8",
1382*da0073e9SAndroid Build Coastguard Worker            "qint32",
1383*da0073e9SAndroid Build Coastguard Worker            "bool",
1384*da0073e9SAndroid Build Coastguard Worker            "quint4x2",
1385*da0073e9SAndroid Build Coastguard Worker            "quint2x4",
1386*da0073e9SAndroid Build Coastguard Worker            "bits1x8",
1387*da0073e9SAndroid Build Coastguard Worker            "bits2x4",
1388*da0073e9SAndroid Build Coastguard Worker            "bits4x2",
1389*da0073e9SAndroid Build Coastguard Worker            "bits8",
1390*da0073e9SAndroid Build Coastguard Worker            "bits16",
1391*da0073e9SAndroid Build Coastguard Worker        ]
1392*da0073e9SAndroid Build Coastguard Worker    ]
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker    # Generate __all__ directive
1395*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker    # Include only the functions that contain hints, to prevent undefined
1398*da0073e9SAndroid Build Coastguard Worker    # symbols to be included in the `__all__` directive.
1399*da0073e9SAndroid Build Coastguard Worker    hinted_function_names = [
1400*da0073e9SAndroid Build Coastguard Worker        name for name, hint in unsorted_function_hints.items() if hint
1401*da0073e9SAndroid Build Coastguard Worker    ]
1402*da0073e9SAndroid Build Coastguard Worker    all_symbols = sorted(list(structseqs.keys()) + hinted_function_names)
1403*da0073e9SAndroid Build Coastguard Worker    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
1404*da0073e9SAndroid Build Coastguard Worker    all_directive[0] = f"__all__ = {all_directive[0]}"
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker    # Dispatch key hints
1407*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~
1408*da0073e9SAndroid Build Coastguard Worker    dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey]
1409*da0073e9SAndroid Build Coastguard Worker    torch_dispatch_mode_key_hints = [
1410*da0073e9SAndroid Build Coastguard Worker        f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey
1411*da0073e9SAndroid Build Coastguard Worker    ]
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker    # Tags Enum type hints
1414*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker    tag_names = sorted(parse_tags_yaml(tags_yaml_path))
1417*da0073e9SAndroid Build Coastguard Worker    tag_attributes = "\n".join(
1418*da0073e9SAndroid Build Coastguard Worker        f"{name}: _int = {index}" for index, name in enumerate(tag_names)
1419*da0073e9SAndroid Build Coastguard Worker    )
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker    # Write out the stub
1422*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker    env = {
1425*da0073e9SAndroid Build Coastguard Worker        "structseq_defs": structseq_defs,
1426*da0073e9SAndroid Build Coastguard Worker        "function_hints": function_hints,
1427*da0073e9SAndroid Build Coastguard Worker        "tensor_method_hints": tensor_method_hints,
1428*da0073e9SAndroid Build Coastguard Worker        "legacy_class_hints": legacy_class_hints,
1429*da0073e9SAndroid Build Coastguard Worker        "legacy_storage_base_hints": legacy_storage_base_hints,
1430*da0073e9SAndroid Build Coastguard Worker        "dtype_class_hints": dtype_class_hints,
1431*da0073e9SAndroid Build Coastguard Worker        "dispatch_key_hints": dispatch_key_hints,
1432*da0073e9SAndroid Build Coastguard Worker        "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
1433*da0073e9SAndroid Build Coastguard Worker        "all_directive": all_directive,
1434*da0073e9SAndroid Build Coastguard Worker        "tag_attributes": tag_attributes,
1435*da0073e9SAndroid Build Coastguard Worker    }
1436*da0073e9SAndroid Build Coastguard Worker    fm.write_with_template(
1437*da0073e9SAndroid Build Coastguard Worker        "torch/_C/__init__.pyi",
1438*da0073e9SAndroid Build Coastguard Worker        "torch/_C/__init__.pyi.in",
1439*da0073e9SAndroid Build Coastguard Worker        lambda: env,
1440*da0073e9SAndroid Build Coastguard Worker    )
1441*da0073e9SAndroid Build Coastguard Worker    fm.write_with_template(
1442*da0073e9SAndroid Build Coastguard Worker        "torch/_C/_VariableFunctions.pyi",
1443*da0073e9SAndroid Build Coastguard Worker        "torch/_C/_VariableFunctions.pyi.in",
1444*da0073e9SAndroid Build Coastguard Worker        lambda: env,
1445*da0073e9SAndroid Build Coastguard Worker    )
1446*da0073e9SAndroid Build Coastguard Worker    fm.write_with_template(
1447*da0073e9SAndroid Build Coastguard Worker        "torch/_VF.pyi",
1448*da0073e9SAndroid Build Coastguard Worker        "torch/_C/_VariableFunctions.pyi.in",
1449*da0073e9SAndroid Build Coastguard Worker        lambda: env,
1450*da0073e9SAndroid Build Coastguard Worker    )
1451*da0073e9SAndroid Build Coastguard Worker    fm.write_with_template(
1452*da0073e9SAndroid Build Coastguard Worker        "torch/return_types.pyi",
1453*da0073e9SAndroid Build Coastguard Worker        "torch/_C/return_types.pyi.in",
1454*da0073e9SAndroid Build Coastguard Worker        lambda: env,
1455*da0073e9SAndroid Build Coastguard Worker    )
1456*da0073e9SAndroid Build Coastguard Worker    gen_nn_functional(fm)
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker
def main() -> None:
    """Command-line entry point: parse options and generate the PyTorch type stubs.

    Every option is a path. The defaults are relative paths, so they presumably
    expect the script to be run from the root of a PyTorch checkout (TODO
    confirm before invoking from another working directory).
    """
    # (flag, metavar, default, help) for each CLI option, in registration order.
    cli_options = (
        (
            "--native-functions-path",
            "NATIVE",
            "aten/src/ATen/native/native_functions.yaml",
            "path to native_functions.yaml",
        ),
        (
            "--tags-path",
            "TAGS",
            "aten/src/ATen/native/tags.yaml",
            "path to tags.yaml",
        ),
        (
            "--deprecated-functions-path",
            "DEPRECATED",
            "tools/autograd/deprecated.yaml",
            "path to deprecated.yaml",
        ),
        ("--out", "OUT", ".", "path to output directory"),
    )

    cli = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    for flag, placeholder, default_value, help_text in cli_options:
        cli.add_argument(
            flag, metavar=placeholder, default=default_value, help=help_text
        )
    opts = cli.parse_args()

    # Output goes under --out; template_dir="." means templates are looked up
    # relative to the current working directory.
    stub_writer = FileManager(install_dir=opts.out, template_dir=".", dry_run=False)
    gen_pyi(
        opts.native_functions_path,
        opts.tags_path,
        opts.deprecated_functions_path,
        stub_writer,
    )


if __name__ == "__main__":
    # Generate stubs only when executed as a script, never on import.
    main()
1491