xref: /aosp_15_r20/external/pytorch/tools/pyi/gen_pyi.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from __future__ import annotations

import argparse
import collections
import importlib
import sys
from pprint import pformat
from typing import Sequence
from unittest.mock import Mock, patch
from warnings import warn

from tools.autograd.gen_python_functions import (
    group_overloads,
    load_signatures,
    should_generate_py_binding,
)

from torchgen.api.python import (
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    returns_structseq_pyi,
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager


"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.

At the moment, this module only handles type stubs for torch and
torch.Tensor.  It should eventually be expanded to cover all functions
which are autogenerated.

Here's our general strategy:

- We start off with a hand-written __init__.pyi.in file.  This
  file contains type definitions for everything we cannot automatically
  generate, including pure Python definitions directly in __init__.py
  (the latter case should be pretty rare).

- We go through automatically bound functions based on the
  type information recorded in native_functions.yaml and
  generate type hints for them (generate_type_hints).

There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""
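
# For orientation, a generated hint for a torch-level function ends up as a
# line like the following in the emitted .pyi files (an illustrative sketch
# only; the exact signature is derived from native_functions.yaml):
#
#     def abs(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...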


def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Get declarations (grouped by name) which should be generated
    as either functions in the "torch" module or methods on Tensor.
    """

    def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and Variant.function in python_func.function.variants
        )

    def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and Variant.method in python_func.function.variants
        )

    should_bind = should_bind_method if method else should_bind_function
    return group_overloads([f for f in python_funcs if should_bind(f)])


# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier on the human eye.

DEVICE_PARAM = "device: Optional[DeviceLikeType] = None"
FACTORY_PARAMS = f"dtype: Optional[_dtype] = None, {DEVICE_PARAM}, requires_grad: _bool = False, pin_memory: _bool = False"

# NOTE: specifying indices for Tensor.__getitem__
# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi:
#
# key: (
#     None
#     | slice
#     | ellipsis
#     | SupportsIndex
#     | _ArrayLikeInt_co
#     | tuple[None | slice | ellipsis | _ArrayLikeInt_co | SupportsIndex, ...]
# )
#
# where:
#
# _ArrayLikeInt_co = _DualArrayLike[
#     dtype[Union[bool_, integer[Any]]],
#     Union[bool, int],
# ]
#
# and
#
# _DualArrayLike = Union[
#     _SupportsArray[_DType],
#     _NestedSequence[_SupportsArray[_DType]],
#     _T,
#     _NestedSequence[_T],
# ]
#
# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple.
# We can substitute and simplify:
# _SupportsArray -> Tensor
# _ArrayLikeInt_co -> [bool | int | Tensor | NestedSequence[bool | int] | NestedSequence[Tensor]]
# which leaves us with key: T | tuple[T, ...], where T is:
# T = (
#     None | bool | int | slice | ellipsis | SupportsIndex
#     | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int]
# )

# NOTE: ellipsis is equal to type[Ellipsis] in stub files.
_leaf_types = "Union[None, _bool, _int, slice, ellipsis, Tensor]"  # not SupportsIndex!
_index = f"Union[SupportsIndex, {_leaf_types}, _NestedSequence[{_leaf_types}]]"
INDICES = f"indices: Union[{_index}, tuple[{_index}, ...]]"
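
# Illustrative examples of indices accepted by the annotation above (assuming
# `t` is a Tensor; each line exercises one of the leaf or nested cases):
#
#     t[None]              # None
#     t[0, 1:2, ...]       # tuple of int / slice / ellipsis
#     t[t > 0]             # Tensor (boolean mask)
#     t[[0, 1], [2, 3]]    # _NestedSequence of int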

blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp
    "add",
    "add_",
    "add_out",
    "sub",
    "sub_",
    "sub_out",
    "mul",
    "mul_",
    "mul_out",
    "div",
    "div_",
    "div_out",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "floor_divide",
    "floor_divide_",
    "floor_divide_out",
    "to",
    "_to_copy",
    "copy_",
]

binary_ops = (
    "add",
    "sub",
    "mul",
    "div",
    "pow",
    "lshift",
    "rshift",
    "mod",
    "truediv",
    "matmul",
    "floordiv",
    "radd",
    "rsub",
    "rmul",
    "rtruediv",
    "rfloordiv",
    "rpow",  # reverse arithmetic
    "and",
    "or",
    "xor",
    "rand",
    "ror",
    "rxor",  # logic
    "iadd",
    "iand",
    "idiv",
    "ilshift",
    "imul",
    "ior",
    "irshift",
    "isub",
    "ixor",
    "ifloordiv",
    "imod",  # inplace ops
)
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ("neg", "abs", "invert")
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname: str) -> list[str]:
    """sig_for_ops(opname: str) -> list[str]

    Returns signatures for operator special functions (__add__ etc.).
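
    Example (illustrative; the result follows directly from the op tables
    above):

        >>> sig_for_ops("__add__")
        ['def __add__(self, other: Any) -> Tensor: ...']
    """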

    # we have to do this by hand, because they are hand-bound in Python

    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"

    name = opname[2:-2]
    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
    elif name in comparison_ops:
        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
        if name in symmetric_comparison_ops:
            # unsafe override https://github.com/python/mypy/issues/5704
            sig += "  # type: ignore[override]"
        return [sig]
    elif name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]
    elif name in to_py_type_ops:
        if name in {"bool", "float", "complex"}:
            tname = name
        elif name == "nonzero":
            tname = "bool"
        else:
            tname = "int"
        if tname in {"float", "int", "bool", "complex"}:
            tname = "builtins." + tname
        return [f"def {opname}(self) -> {tname}: ..."]
    else:
        raise Exception("unknown op", opname)  # noqa: TRY002


def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
    type_hints: list[str] = []

    # Some deprecated ops that are on the blocklist are still included in pyi
    if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
        return type_hints

    # deprecated signatures have separate entries for their functional and out variants
    # (as opposed to the native ops, which fuse the two into a single signature).
    # generate the functional variant here, if an out variant exists.
    if sig_group.signature.deprecated and sig_group.outplace is not None:
        type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True)
        type_hints.append(type_hint)

    # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument
    # Generate the out variant if one exists; otherwise, generate the functional variant.
    type_hint = sig_group.signature.signature_str_pyi(
        skip_outputs=sig_group.outplace is None
    )
    type_hints.append(type_hint)

    # Some operators additionally have a vararg variant of their signature
    type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
        skip_outputs=sig_group.outplace is None
    )
    if type_hint_vararg:
        type_hints.append(type_hint_vararg)

    return type_hints


def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
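    """Expand a max-pool-style arg list into the three overloads produced by
    torch._jit_internal.boolean_dispatch on return_indices.

    Illustrative example (worked out by hand from the template below) for
    adaptive_max_pool2d with arg_list = [input, output_size, "{return_indices}"]:

        def adaptive_max_pool2d(..., return_indices: Literal[False] = False) -> Tensor: ...
        def adaptive_max_pool2d(..., return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
        def adaptive_max_pool2d(..., *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
    """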
    flag_pos = arg_list.index("{return_indices}")
    # If return_indices is a positional arg, everything before it should have no default
    arg_list_positional = (
        [
            ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", "))
            for arg in arg_list[: flag_pos + 1]
        ]
        + ["/"]
        + arg_list[flag_pos + 1 :]
    )
    # Otherwise, force return_indices to be a kwarg
    arg_list_keyword = arg_list.copy()
    arg_list_keyword.insert(flag_pos, "*")
    tmpl = "def {name}({args}) -> {{return_type}}: ..."
    return {
        name: [
            tmpl.format(name=name, args=", ".join(arg_list)).format(
                return_indices="return_indices: Literal[False] = False",
                return_type="Tensor",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
        ]
    }


def gen_nn_functional(fm: FileManager) -> None:
    INPUT = "input: Tensor"
    KERNEL_SIZE = "kernel_size: Union[_int, _size]"
    STRIDE_PADDING = ", ".join(
        [
            "stride: Optional[Union[_int, _size]] = None",
            "padding: Union[_int, _size] = 0",
        ]
    )

    # TODO the list for `torch._C._nn` is nonexhaustive
    unsorted_c_nn_function_hints: dict[str, list[str]] = {}

    for d in (2, 3):
        unsorted_c_nn_function_hints.update(
            {
                f"avg_pool{d}d": [
                    f"def avg_pool{d}d({{}}) -> Tensor: ...".format(
                        ", ".join(
                            [
                                f"{INPUT}",
                                f"{KERNEL_SIZE}",
                                f"{STRIDE_PADDING}",
                                "ceil_mode: bool = False",
                                "count_include_pad: bool = True",
                                "divisor_override: Optional[int] = None",
                            ]
                        )
                    )
                ],
                f"fractional_max_pool{d}d": [
                    f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format(
                        ", ".join(
                            [
                                f"{INPUT}",
                                f"{KERNEL_SIZE}",
                                "output_size: Union[_int, _size]",
                                "_random_samples: Tensor",
                            ]
                        ),
                        "Tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_max_pool{d}d": [
                    f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
                        ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]),
                        "Tuple[Tensor, Tensor]",
                    )
                ],
            }
        )

    unsorted_c_nn_function_hints.update(
        {
            "hardtanh": [
                "def hardtanh({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "min_val: float = ...",
                            "max_val: float = ...",
                            "*",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "hardtanh_": [
                "def hardtanh_({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "min_val: float = ...",
                            "max_val: float = ...",
                        ]
                    )
                )
            ],
            "elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."],
            "leaky_relu": [
                "def leaky_relu({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "negative_slope: float = ...",
                            "*",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "leaky_relu_": [
                f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..."
            ],
            "log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."],
            "gelu": ["def gelu(input: Tensor, approximate: str = ...) -> Tensor: ..."],
            "softplus": [
                "def softplus({}) -> Tensor: ...".format(
                    ", ".join(
                        ["input: Tensor", "beta: float = ...", "threshold: float = ..."]
                    )
                )
            ],
            "softshrink": [
                "def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..."
            ],
            "hardsigmoid": [
                f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..."
            ],
            "linear": [
                "def linear({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "weight: Tensor",
                            "bias: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "pad": [
                "def pad({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "pad: Sequence[int]",
                            "mode: str = ...",
                            "value: Optional[float] = None",
                        ]
                    )
                )
            ],
            "one_hot": [
                "def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..."
            ],
            "scaled_dot_product_attention": [
                "def scaled_dot_product_attention({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "query: Tensor",
                            "key: Tensor",
                            "value: Tensor",
                            "attn_mask: Optional[Tensor] = None",
                            "dropout_p: float = 0.0",
                            "is_causal: bool = False",
                            "scale: Optional[float] = None",
                            "enable_gqa: bool = False",
                        ]
                    )
                )
            ],
        }
    )

    c_nn_function_hints: list[str] = []
    for _, hints in sorted(unsorted_c_nn_function_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        c_nn_function_hints += hints

    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    torch_imports = [
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv_tbc",
        "avg_pool1d",
        "adaptive_avg_pool1d",
        "relu_",
        "selu_",
        "celu_",
        "prelu",
        "rrelu_",
        "hardshrink",
        "bilinear",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "native_channel_shuffle",
        "pairwise_distance",
        "pdist",
        "cosine_similarity",
    ]
    imported_hints = [f"from torch import {_} as {_}" for _ in torch_imports]

    # Functions imported into `torch.nn.functional` from `torch._C._nn`
    c_nn_imports = [
        "avg_pool2d",
        "avg_pool3d",
        "hardtanh_",
        "elu_",
        "leaky_relu_",
        "gelu",
        "softplus",
        "softshrink",
        "linear",
        "pad",
        "one_hot",
        "scaled_dot_product_attention",
    ]
    imported_hints += [f"from torch._C._nn import {_} as {_}" for _ in c_nn_imports]
    # This is from `torch._C._nn` but renamed
    imported_hints.append(
        "from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid"
    )

    # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
    unsorted_dispatched_hints: dict[str, list[str]] = {}

    for d in (1, 2, 3):
        unsorted_dispatched_hints.update(
            **get_max_pool_dispatch(
                f"max_pool{d}d",
                [
                    f"{INPUT}",
                    f"{KERNEL_SIZE}",
                    f"{STRIDE_PADDING}",
                    "dilation: Union[_int, _size] = 1",
                    "ceil_mode: bool = False",
                    "{return_indices}",
                ],
            ),
            **get_max_pool_dispatch(
                f"fractional_max_pool{d}d",
                [
                    f"{INPUT}",
                    f"{KERNEL_SIZE}",
                    "output_size: Optional[Union[_int, _size]] = None",
                    "output_ratio: Optional[_ratio_any_t] = None",
                    "{return_indices}",
                    "_random_samples: Optional[Tensor] = None",
                ],
            ),
            **get_max_pool_dispatch(
                f"adaptive_max_pool{d}d",
                [f"{INPUT}", "output_size: Union[_int, _size]", "{return_indices}"],
            ),
        )

    # There's no fractional_max_pool1d
    del unsorted_dispatched_hints["fractional_max_pool1d"]

    dispatched_hints: list[str] = []
    for _, hints in sorted(unsorted_dispatched_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        dispatched_hints += hints

    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": imported_hints,
            "dispatched_hints": dispatched_hints,
        },
    )
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "c_nn_function_hints": c_nn_function_hints,
        },
    )


"""
We gather the docstrings for torch with the following steps:
1. Mock torch and torch._C, which are the only dependencies of the docs files
2. Mock the _add_docstr function to save the docstrings
3. Import the docs files to trigger the mocked _add_docstr and collect docstrings
"""


def gather_docstrs() -> dict[str, str]:
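    """Collect the docstrings registered via _add_docstr.

    Returns a mapping from the mocked attribute path to the docstring text,
    e.g. (illustrative) {"torch.abs": "abs(input) -> Tensor ..."}; the keys
    are looked up later as f"torch.{name}" and f"torch._C.TensorBase.{name}".
    """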
    docstrs = {}

    def mock_add_docstr(func: Mock, docstr: str) -> None:
        docstrs[func._extract_mock_name()] = docstr.strip()

    # sys.modules and sys.path are restored after the context manager exits
    with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]):
        # mock the torch module and torch._C._add_docstr
        sys.modules["torch"] = Mock(name="torch")
        sys.modules["torch._C"] = Mock(_add_docstr=mock_add_docstr)

        try:
            # manually import torch._torch_docs and torch._tensor_docs to trigger
            # the mocked _add_docstr and collect docstrings
            sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs")
            sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs")
        except ModuleNotFoundError:
            # Gracefully fail if these modules are not importable
            warn(
                "Failed to import _torch_docs/_tensor_docs, skipping docstrings in pyi files."
            )

    return docstrs


def add_docstr_to_hint(docstr: str, hint: str) -> str:
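    """Embed docstr into hint. Illustrative sketch for a method hint:

        def sign(self) -> Tensor: ...

    becomes (the real output uses a raw triple-double-quoted string):

        def sign(self) -> Tensor:
            r'''
            <docstr>
            '''
            ...

    Attribute/property hints instead get the docstring appended after them.
    """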
    if "..." in hint:  # function or method
        assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
        hint = hint[:-3]  # remove "..."
        return "\n    ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
    else:  # attribute or property
        return f'{hint}\nr"""{docstr}"""\n'


def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates the pyi files for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates a custom format for argument
    # checking.  If you are updating this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    structseqs: dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)

    for n, n1, n2 in [
        ("csr", "crow", "col"),
        ("csc", "ccol", "row"),
        ("bsr", "crow", "col"),
        ("bsc", "ccol", "row"),
    ]:
        unsorted_function_hints.update(
            {
                f"sparse_{n}_tensor": [
                    f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format(
                        ", ".join(
                            [
                                f"{n1}_indices: Union[Tensor, List]",
                                f"{n2}_indices: Union[Tensor, List]",
                                "values: Union[Tensor, List]",
                                "size: Optional[_size] = None",
                                "*",
                                "dtype: Optional[_dtype] = None",
                                "device: Optional[DeviceLikeType] = None",
                                "requires_grad: _bool = False",
                                "check_invariants: Optional[_bool] = None",
                            ]
                        ),
                    )
                ],
            }
        )

    unsorted_function_hints.update(
        {
            "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
            "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
            "asarray": [
                "def asarray({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "obj: Any",
                            "*",
                            "dtype: Optional[_dtype] = None",
                            "device: Optional[DeviceLikeType] = None",
                            "copy: Optional[_bool] = None",
                            "requires_grad: _bool = False",
                        ]
                    )
                )
            ],
            "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
            "frombuffer": [
                "def frombuffer({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "buffer: Any",
                            "*",
                            "dtype: _dtype",
                            "count: int = -1",
                            "offset: int = 0",
                            "requires_grad: _bool = False",
                        ]
                    )
                )
            ],
            "numel": ["def numel(self: Tensor) -> _int: ..."],
            "as_tensor": [
                "def as_tensor({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "data: Any",
                            "dtype: Optional[_dtype] = None",
                            DEVICE_PARAM,
                        ]
                    )
                )
            ],
            "get_num_threads": ["def get_num_threads() -> _int: ..."],
            "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
            "init_num_threads": ["def init_num_threads() -> None: ..."],
            "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."],
            "set_num_interop_threads": [
                "def set_num_interop_threads(num: _int) -> None: ..."
            ],
            # These functions are explicitly disabled by
            # SKIP_PYTHON_BINDINGS because they are hand bound.
            # Correspondingly, we must hand-write their signatures.
            "tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."],
            "sparse_coo_tensor": [
                "def sparse_coo_tensor({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "indices: Tensor",
                            "values: Union[Tensor, List]",
                            "size: Optional[_size] = None",
                            "*",
                            "dtype: Optional[_dtype] = None",
                            "device: Optional[DeviceLikeType] = None",
                            "requires_grad: _bool = False",
                            "check_invariants: Optional[_bool] = None",
                            "is_coalesced: Optional[_bool] = None",
                        ]
                    )
                )
            ],
            "sparse_compressed_tensor": [
                "def sparse_compressed_tensor({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "compressed_indices: Union[Tensor, List]",
                            "plain_indices: Union[Tensor, List]",
                            "values: Union[Tensor, List]",
                            "size: Optional[_size] = None",
                            "*",
                            "dtype: Optional[_dtype] = None",
                            "layout: Optional[_layout] = None",
                            "device: Optional[DeviceLikeType] = None",
                            "requires_grad: _bool = False",
                            "check_invariants: Optional[_bool] = None",
                        ]
                    )
                )
            ],
            "_sync": ["def _sync(t: Tensor) -> None: ..."],
            "_is_functional_tensor": [
                "def _is_functional_tensor(t: Tensor) -> _bool: ..."
            ],
            "_is_functional_tensor_base": [
                "def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
            ],
            "_from_functional_tensor": [
                "def _from_functional_tensor(t: Tensor) -> Tensor: ..."
            ],
            "_to_functional_tensor": [
                "def _to_functional_tensor(t: Tensor) -> Tensor: ..."
            ],
            "_functionalize_replace": [
                "def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..."
            ],
            "_functionalize_commit_update": [
                "def _functionalize_commit_update(t: Tensor) -> None: ..."
            ],
            "_functionalize_unsafe_set": [
                "def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
            ],
            "_functionalize_mark_mutation_hidden_from_autograd": [
                "def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..."
            ],
            "_functionalize_are_all_mutations_hidden_from_autograd": [
                "def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ..."
            ],
            "_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [
                "def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ..."
            ],
            "_functionalize_was_inductor_storage_resized": [
                "def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ..."
            ],
            "_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."],
            "_functionalize_was_storage_changed": [
                "def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..."
            ],
            "_functionalize_set_storage_changed": [
                "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..."
            ],
            "_functionalize_has_metadata_mutation": [
                "def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..."
            ],
            "_functionalize_apply_view_metas": [
                "def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..."
            ],
            "_functionalize_is_symbolic": [
                "def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..."
            ],
            "_enable_functionalization": [
                "def _enable_functionalization(*, reapply_views: _bool = False): ..."
            ],
            "_disable_functionalization": ["def _disable_functionalization(): ..."],
            "range": [
                "def range({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "step: Number = 1",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                )
            ],
            "arange": [
                "def arange({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "step: Number",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def arange({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def arange({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "end: Number",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
            ],
            "linspace": [
                "def linspace({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "steps: Optional[_int] = None",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                )
            ],
            "logspace": [
                "def logspace({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "steps: Optional[_int] = None",
                            "base: _float = 10.0",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                )
            ],
            "randint": [
                "def randint({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "low: _int",
                            "high: _int",
                            "size: _size",
                            "*",
                            "generator: Optional[Generator] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def randint({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "high: _int",
                            "size: _size",
                            "*",
                            "generator: Optional[Generator] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
            ],
            "full": [
                "def full({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "size: _size",
                            "fill_value: Union[Number, _complex]",
                            "*",
                            "out: Optional[Tensor] = None",
                            "layout: _layout = strided",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def full({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "size: _size",
                            "fill_value: Union[Number, _complex]",
                            "*",
                            "names: List[Union[str, None]]",
                            "layout: _layout = strided",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
            ],
            "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
            "is_inference_mode_enabled": [
                "def is_inference_mode_enabled() -> _bool: ..."
            ],
            "nonzero": [
                "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...",
                "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
            ],
            "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "saddmm": [
                "def saddmm({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "mat1: Tensor",
                            "mat2: Tensor",
                            "*",
                            "beta: Number = 1",
                            "alpha: Number = 1",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "div": [
                "def div({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Union[Tensor, Number]",
                            "other: Union[Tensor, Number]",
                            "*",
                            "rounding_mode: Optional[str] = None",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
        }
    )
    for binop in ["true_divide", "floor_divide"]:
        unsorted_function_hints[binop].append(
            f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], "
            "*, out: Optional[Tensor] = None) -> Tensor: ..."
        )
    for binop in ["mul"]:
        unsorted_function_hints[binop].append(
            f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], "
            "*, out: Optional[Tensor] = None) -> Tensor: ..."
        )
    for binop in ["add", "sub"]:
        unsorted_function_hints[binop].append(
            f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], "
            "*, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: ..."
        )

    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(
        native_functions, deprecated_yaml_path, method=False, pyi=True
    )
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)

        structseq = returns_structseq_pyi(group.signature)
        if structseq is not None and not group.signature.deprecated:
            # deprecated structseqs are currently not included for torch functions
            tuple_name, tuple_def = structseq
            if tuple_name in structseqs:
                assert structseqs[tuple_name] == tuple_def
            else:
                structseqs[tuple_name] = tuple_def

    def replace_special_case(hint: str) -> str:
        # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h
        hint = hint.replace("at::Reduction::Mean", "1")
        hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None")
        # Match both:
        # ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None"
        # ": Union[Tuple[Tensor, ...], List[Tensor]] = None"
        hint = hint.replace(
            "Tuple[Tensor, ...], List[Tensor]] = None",
            "Tuple[Tensor, ...], List[Tensor], None] = None",
        )
        return hint

    docstrs = gather_docstrs()
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
        hints = [replace_special_case(h) for h in hints]
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        docstr = docstrs.get(f"torch.{name}")
        if docstr is not None:
            hints = [add_docstr_to_hint(docstr, h) for h in hints]
        function_hints += hints

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update(
        {
            "size": [
                "def size(self, dim: None = None) -> Size: ...",
                "def size(self, dim: _int) -> _int: ...",
            ],
            "stride": [
                "def stride(self, dim: None = None) -> Tuple[_int, ...]: ...",
                "def stride(self, dim: _int) -> _int: ...",
            ],
            "new_ones": [
                f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..."
            ],
            "new_tensor": [
                f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..."
            ],
            "__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."],
            # new and __init__ have the same signatures and differ only in return type.
            # Adapted from legacy_tensor_ctor and legacy_tensor_new
            "new": [
                f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...",
                "def new(cls, storage: Storage) -> Self: ...",
                "def new(cls, other: Tensor) -> Self: ...",
                f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...",
            ],
            "__init__": [
                f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...",
                "def __init__(self, storage: Storage) -> None: ...",
                "def __init__(self, other: Tensor) -> None: ...",
                f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...",
            ],
            "as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."],
            "_make_subclass": [
                "@staticmethod    \ndef _make_subclass({}) -> S: ...".format(
                    ", ".join(
                        [
                            "cls: _Type[S]",
                            "data: Tensor",
                            "require_grad: _bool = False",
                            "dispatch_strides: _bool = False",
                            "dispatch_device: _bool = False",
                            "device_for_backend_keys: Optional[_device] = None",
                        ]
                    )
                )
            ],
            "__contains__": ["def __contains__(self, other: Any, /) -> _bool: ..."],
            "__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."],
            "__setitem__": [
                f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..."
            ],
            "tolist": ["def tolist(self) -> List: ..."],
            "requires_grad_": [
                "def requires_grad_(self, mode: _bool = True) -> Tensor: ..."
            ],
            "element_size": ["def element_size(self) -> _int: ..."],
            "data_ptr": ["def data_ptr(self) -> _int: ..."],
            "dim": ["def dim(self) -> _int: ..."],
            "nonzero": [
                "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...",
                "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
            ],
            "numel": ["def numel(self) -> _int: ..."],
            "ndimension": ["def ndimension(self) -> _int: ..."],
            "nelement": ["def nelement(self) -> _int: ..."],
            "cuda": [
                "def cuda({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "self",
                            "device: Optional[Union[_device, _int, str]] = None",
                            "non_blocking: _bool = False",
                            "memory_format: torch.memory_format = torch.preserve_format",
                        ]
                    )
                )
            ],
            "xpu": [
                "def xpu({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "self",
                            "device: Optional[Union[_device, _int, str]] = None",
                            "non_blocking: _bool = False",
                            "memory_format: torch.memory_format = torch.preserve_format",
                        ]
                    )
                )
            ],
            "cpu": [
                "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..."
            ],
            "numpy": ["def numpy(self, *, force: _bool = False) -> numpy.ndarray: ..."],
            "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
            "map_": [
                "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."
            ],
            "map2_": [
                "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
            ],
            "storage": ["def untyped_storage(self) -> UntypedStorage: ..."],
            "storage_type": ["def storage_type(self) -> Storage: ..."],
            "type": [
                "def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...",
                "def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...",
            ],
            "get_device": ["def get_device(self) -> _int: ..."],
            "contiguous": [
                "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
            ],
            "has_names": ["def has_names(self) -> _bool: ..."],
            "is_contiguous": [
                "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
            ],
            "_is_view": ["def _is_view(self) -> _bool: ..."],
            "is_cpu": ["is_cpu: _bool"],
            "is_cuda": ["is_cuda: _bool"],
            "is_leaf": ["is_leaf: _bool"],
            "is_nested": ["is_nested: _bool"],
            "is_sparse": ["is_sparse: _bool"],
            "is_sparse_csr": ["is_sparse_csr: _bool"],
            "is_quantized": ["is_quantized: _bool"],
            "is_meta": ["is_meta: _bool"],
            "is_mps": ["is_mps: _bool"],
            "is_mtia": ["is_mtia: _bool"],
            "is_maia": ["is_maia: _bool"],
            "is_mkldnn": ["is_mkldnn: _bool"],
            "is_vulkan": ["is_vulkan: _bool"],
            "is_ipu": ["is_ipu: _bool"],
            "storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."],
            "to": [
                (
                    f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, "
                    "memory_format: Optional[torch.memory_format] = None) -> Tensor: ..."
                )
                for args in [
                    "dtype: _dtype",
                    "device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None",
                    "other: Tensor",
                ]
            ],
            "item": ["def item(self) -> Number: ..."],
            "copy_": [
                "def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..."
            ],
            "set_": [
                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
                "offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...",
                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
            ],
            "split": [
                "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...",
                "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
            ],
            "div": [
                "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
            ],
            "div_": [
                "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
            ],
        }
    )
    for binop in ["true_divide", "floor_divide"]:
        for inplace in [False, True]:
            out_suffix = ", *, out: Optional[Tensor] = None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})"
                " -> Tensor: ..."
            )
    for binop in ["mul"]:
        for inplace in [False, True]:
            out_suffix = ", *, out: Optional[Tensor] = None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat]{out_suffix})"
                " -> Tensor: ..."
            )
    for binop in ["add", "sub"]:
        for inplace in [False, True]:
            out_suffix = ", out: Optional[Tensor] = None"
            if inplace:
                binop += "_"
                out_suffix = ""
            unsorted_tensor_method_hints[binop].append(
                f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], "
                f"*, alpha: Optional[Union[Number, _complex]] = 1{out_suffix})"
                " -> Tensor: ..."
            )
    simple_conversions = [
        "byte",
        "char",
        "double",
        "float",
        "half",
        "int",
        "long",
        "short",
        "bool",
        "bfloat16",
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(
        native_functions,
        deprecated_yaml_path,
        method=True,
        skip_deprecated=True,
        pyi=True,
    )
    tensor_method_sig_groups = get_py_torch_functions(
        tensor_method_signatures, method=True
    )

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)

        structseq = returns_structseq_pyi(group.signature)
        if structseq is not None and not group.signature.deprecated:
            # deprecated structseqs are currently not included for torch functions
            tuple_name, tuple_def = structseq
            if tuple_name in structseqs:
                assert structseqs[tuple_name] == tuple_def
            else:
                structseqs[tuple_name] = tuple_def

    for op in all_ops:
        name = f"__{op}__"
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = []
    for name, hints in sorted(unsorted_tensor_method_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        docstr = docstrs.get(f"torch._C.TensorBase.{name}")
        if docstr is not None:
            hints = [add_docstr_to_hint(docstr, h) for h in hints]
        tensor_method_hints += hints

    # TODO: Missing type hints for nn

    # Generate structseq definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    structseq_defs = [f"{defn}\n" for defn in structseqs.values()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_storage_base_hints = ["class StorageBase(object): ..."]

    legacy_class_hints = []
    for c in (
        "DoubleTensor",
        "FloatTensor",
        "BFloat16Tensor",
        "LongTensor",
        "IntTensor",
        "ShortTensor",
        "HalfTensor",
        "CharTensor",
        "ByteTensor",
        "BoolTensor",
    ):
        legacy_class_hints.append(f"class {c}(Tensor): ...")

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        f"{n}: dtype = ..."
        for n in [
            "float32",
            "float",
            "float64",
            "double",
            "float16",
            "bfloat16",
            "float8_e4m3fn",
            "float8_e4m3fnuz",
            "float8_e5m2",
            "float8_e5m2fnuz",
            "half",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "short",
            "int32",
            "int",
            "int64",
            "long",
            "complex32",
            "complex64",
            "chalf",
            "cfloat",
            "complex128",
            "cdouble",
            "quint8",
            "qint8",
            "qint32",
            "bool",
            "quint4x2",
            "quint2x4",
            "bits1x8",
            "bits2x4",
            "bits4x2",
            "bits8",
            "bits16",
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1398    # symbols to be included in the `__all__` directive.
1399    hinted_function_names = [
1400        name for name, hint in unsorted_function_hints.items() if hint
1401    ]
1402    all_symbols = sorted(list(structseqs.keys()) + hinted_function_names)
1403    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
1404    all_directive[0] = f"__all__ = {all_directive[0]}"
1405
1406    # Dispatch key hints
1407    # ~~~~~~~~~~~~~~~~~~
1408    dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey]
1409    torch_dispatch_mode_key_hints = [
1410        f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey
1411    ]
1412
1413    # Tags Enum type hints
1414    # ~~~~~~~~~~~~~~~~~~~~
1415
1416    tag_names = sorted(parse_tags_yaml(tags_yaml_path))
1417    tag_attributes = "\n".join(
1418        f"{name}: _int = {index}" for index, name in enumerate(tag_names)
1419    )
1420
1421    # Write out the stub
1422    # ~~~~~~~~~~~~~~~~~~
1423
1424    env = {
1425        "structseq_defs": structseq_defs,
1426        "function_hints": function_hints,
1427        "tensor_method_hints": tensor_method_hints,
1428        "legacy_class_hints": legacy_class_hints,
1429        "legacy_storage_base_hints": legacy_storage_base_hints,
1430        "dtype_class_hints": dtype_class_hints,
1431        "dispatch_key_hints": dispatch_key_hints,
1432        "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
1433        "all_directive": all_directive,
1434        "tag_attributes": tag_attributes,
1435    }
1436    fm.write_with_template(
1437        "torch/_C/__init__.pyi",
1438        "torch/_C/__init__.pyi.in",
1439        lambda: env,
1440    )
1441    fm.write_with_template(
1442        "torch/_C/_VariableFunctions.pyi",
1443        "torch/_C/_VariableFunctions.pyi.in",
1444        lambda: env,
1445    )
1446    fm.write_with_template(
1447        "torch/_VF.pyi",
1448        "torch/_C/_VariableFunctions.pyi.in",
1449        lambda: env,
1450    )
1451    fm.write_with_template(
1452        "torch/return_types.pyi",
1453        "torch/_C/return_types.pyi.in",
1454        lambda: env,
1455    )
1456    gen_nn_functional(fm)
1457
1458
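# Example invocation (a sketch; the paths shown are just the argparse defaults
# below, and the module must be run from the repository root so that the
# `tools.autograd` imports above resolve):
#
#     python -m tools.pyi.gen_pyi \
#         --native-functions-path aten/src/ATen/native/native_functions.yaml \
#         --tags-path aten/src/ATen/native/tags.yaml \
#         --deprecated-functions-path tools/autograd/deprecated.yaml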
def main() -> None:
    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument(
        "--out", metavar="OUT", default=".", help="path to output directory"
    )
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
    gen_pyi(
        args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm
    )


if __name__ == "__main__":
    main()