1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport argparse 4*da0073e9SAndroid Build Coastguard Workerimport collections 5*da0073e9SAndroid Build Coastguard Workerimport importlib 6*da0073e9SAndroid Build Coastguard Workerimport sys 7*da0073e9SAndroid Build Coastguard Workerfrom pprint import pformat 8*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence 9*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import Mock, patch 10*da0073e9SAndroid Build Coastguard Workerfrom warnings import warn 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerfrom tools.autograd.gen_python_functions import ( 13*da0073e9SAndroid Build Coastguard Worker group_overloads, 14*da0073e9SAndroid Build Coastguard Worker load_signatures, 15*da0073e9SAndroid Build Coastguard Worker should_generate_py_binding, 16*da0073e9SAndroid Build Coastguard Worker) 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.python import ( 19*da0073e9SAndroid Build Coastguard Worker PythonSignatureGroup, 20*da0073e9SAndroid Build Coastguard Worker PythonSignatureNativeFunctionPair, 21*da0073e9SAndroid Build Coastguard Worker returns_structseq_pyi, 22*da0073e9SAndroid Build Coastguard Worker) 23*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import parse_native_yaml, parse_tags_yaml 24*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant 25*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import FileManager 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker""" 29*da0073e9SAndroid Build Coastguard WorkerThis module implements generation of type stubs for PyTorch, 30*da0073e9SAndroid Build Coastguard Workerenabling use of 
autocomplete in IDEs like PyCharm, which otherwise 31*da0073e9SAndroid Build Coastguard Workerdon't understand C extension modules. 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard WorkerAt the moment, this module only handles type stubs for torch and 34*da0073e9SAndroid Build Coastguard Workertorch.Tensor. It should eventually be expanded to cover all functions 35*da0073e9SAndroid Build Coastguard Workerwhich are autogenerated. 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard WorkerHere's our general strategy: 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker- We start off with a hand-written __init__.pyi.in file. This 40*da0073e9SAndroid Build Coastguard Worker file contains type definitions for everything we cannot automatically 41*da0073e9SAndroid Build Coastguard Worker generate, including pure Python definitions directly in __init__.py 42*da0073e9SAndroid Build Coastguard Worker (the latter case should be pretty rare). 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker- We go through automatically bound functions based on the 45*da0073e9SAndroid Build Coastguard Worker type information recorded in native_functions.yaml and 46*da0073e9SAndroid Build Coastguard Worker generate type hints for them (generate_type_hints). 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard WorkerThere are a number of type hints which we've special-cased; 49*da0073e9SAndroid Build Coastguard Workerread gen_pyi for the gory details.
50*da0073e9SAndroid Build Coastguard Worker""" 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Workerdef get_py_torch_functions( 54*da0073e9SAndroid Build Coastguard Worker python_funcs: Sequence[PythonSignatureNativeFunctionPair], 55*da0073e9SAndroid Build Coastguard Worker method: bool = False, 56*da0073e9SAndroid Build Coastguard Worker) -> Sequence[PythonSignatureGroup]: 57*da0073e9SAndroid Build Coastguard Worker """ 58*da0073e9SAndroid Build Coastguard Worker Get declarations (grouped by name) which should be generated 59*da0073e9SAndroid Build Coastguard Worker as either functions in the "torch" module or methods on Tensor. 60*da0073e9SAndroid Build Coastguard Worker """ 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool: 63*da0073e9SAndroid Build Coastguard Worker return ( 64*da0073e9SAndroid Build Coastguard Worker should_generate_py_binding(python_func.function) 65*da0073e9SAndroid Build Coastguard Worker and not python_func.function.python_module 66*da0073e9SAndroid Build Coastguard Worker and Variant.function in python_func.function.variants 67*da0073e9SAndroid Build Coastguard Worker ) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: 70*da0073e9SAndroid Build Coastguard Worker return ( 71*da0073e9SAndroid Build Coastguard Worker should_generate_py_binding(python_func.function) 72*da0073e9SAndroid Build Coastguard Worker and not python_func.function.python_module 73*da0073e9SAndroid Build Coastguard Worker and Variant.method in python_func.function.variants 74*da0073e9SAndroid Build Coastguard Worker ) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker should_bind = should_bind_method if method else 
should_bind_function 77*da0073e9SAndroid Build Coastguard Worker return group_overloads([f for f in python_funcs if should_bind(f)]) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker# TODO: Consider defining some aliases for our Union[...] types, to make 81*da0073e9SAndroid Build Coastguard Worker# the stubs to read on the human eye. 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard WorkerDEVICE_PARAM = "device: Optional[DeviceLikeType] = None" 84*da0073e9SAndroid Build Coastguard WorkerFACTORY_PARAMS = f"dtype: Optional[_dtype] = None, {DEVICE_PARAM}, requires_grad: _bool = False, pin_memory: _bool = False" 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker# NOTE: specifying indices for Tensor.__getitem__ 87*da0073e9SAndroid Build Coastguard Worker# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi: 88*da0073e9SAndroid Build Coastguard Worker# 89*da0073e9SAndroid Build Coastguard Worker# key: ( 90*da0073e9SAndroid Build Coastguard Worker# None 91*da0073e9SAndroid Build Coastguard Worker# | slice 92*da0073e9SAndroid Build Coastguard Worker# | ellipsis 93*da0073e9SAndroid Build Coastguard Worker# | SupportsIndex 94*da0073e9SAndroid Build Coastguard Worker# | _ArrayLikeInt_co 95*da0073e9SAndroid Build Coastguard Worker# | tuple[None | slice | ellipsis | _ArrayLikeInt_co | SupportsIndex, ...] 
96*da0073e9SAndroid Build Coastguard Worker# ) 97*da0073e9SAndroid Build Coastguard Worker# 98*da0073e9SAndroid Build Coastguard Worker# where: 99*da0073e9SAndroid Build Coastguard Worker# 100*da0073e9SAndroid Build Coastguard Worker# _ArrayLikeInt_co = _DualArrayLike[ 101*da0073e9SAndroid Build Coastguard Worker# dtype[Union[bool_, integer[Any]]], 102*da0073e9SAndroid Build Coastguard Worker# Union[bool, int], 103*da0073e9SAndroid Build Coastguard Worker# ] 104*da0073e9SAndroid Build Coastguard Worker# 105*da0073e9SAndroid Build Coastguard Worker# and 106*da0073e9SAndroid Build Coastguard Worker# 107*da0073e9SAndroid Build Coastguard Worker# _DualArrayLike = Union[ 108*da0073e9SAndroid Build Coastguard Worker# _SupportsArray[_DType], 109*da0073e9SAndroid Build Coastguard Worker# _NestedSequence[_SupportsArray[_DType]], 110*da0073e9SAndroid Build Coastguard Worker# _T, 111*da0073e9SAndroid Build Coastguard Worker# _NestedSequence[_T], 112*da0073e9SAndroid Build Coastguard Worker# ] 113*da0073e9SAndroid Build Coastguard Worker# 114*da0073e9SAndroid Build Coastguard Worker# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple. 
115*da0073e9SAndroid Build Coastguard Worker# We can substitute and simplify: 116*da0073e9SAndroid Build Coastguard Worker# _SupportsArray -> Tensor 117*da0073e9SAndroid Build Coastguard Worker# _ArrayLikeInt_co -> [bool | int | | Tensor | NestedSequence[bool | int] | NestedSequence[Tensor]] 118*da0073e9SAndroid Build Coastguard Worker# which leaves us with key: T | tuple[T, ...], where T is: 119*da0073e9SAndroid Build Coastguard Worker# T = ( 120*da0073e9SAndroid Build Coastguard Worker# None | bool | int | slice | ellipsis | SupportsIndex 121*da0073e9SAndroid Build Coastguard Worker# | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int] 122*da0073e9SAndroid Build Coastguard Worker# ) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker# NOTE: ellipsis is equal to type[Ellipsis] in stub files. 125*da0073e9SAndroid Build Coastguard Worker_leaf_types = "Union[None, _bool, _int, slice, ellipsis, Tensor]" # not SupportsIndex! 126*da0073e9SAndroid Build Coastguard Worker_index = f"Union[SupportsIndex, {_leaf_types}, _NestedSequence[{_leaf_types}]]" 127*da0073e9SAndroid Build Coastguard WorkerINDICES = f"indices: Union[{_index}, tuple[{_index}, ...]]" 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Workerblocklist = [ 130*da0073e9SAndroid Build Coastguard Worker "__init_subclass__", 131*da0073e9SAndroid Build Coastguard Worker "__new__", 132*da0073e9SAndroid Build Coastguard Worker "__subclasshook__", 133*da0073e9SAndroid Build Coastguard Worker "cdist", 134*da0073e9SAndroid Build Coastguard Worker "device", 135*da0073e9SAndroid Build Coastguard Worker "grad", 136*da0073e9SAndroid Build Coastguard Worker "requires_grad", 137*da0073e9SAndroid Build Coastguard Worker "range", 138*da0073e9SAndroid Build Coastguard Worker # defined in functional 139*da0073e9SAndroid Build Coastguard Worker "einsum", 140*da0073e9SAndroid Build Coastguard Worker # Somehow, these are defined in both _C and in 
functional. Ick! 141*da0073e9SAndroid Build Coastguard Worker "broadcast_tensors", 142*da0073e9SAndroid Build Coastguard Worker # Manually define named tensor type stubs in __init__.pyi.in 143*da0073e9SAndroid Build Coastguard Worker "align_tensors", 144*da0073e9SAndroid Build Coastguard Worker "meshgrid", 145*da0073e9SAndroid Build Coastguard Worker "cartesian_prod", 146*da0073e9SAndroid Build Coastguard Worker "block_diag", 147*da0073e9SAndroid Build Coastguard Worker "norm", 148*da0073e9SAndroid Build Coastguard Worker "chain_matmul", 149*da0073e9SAndroid Build Coastguard Worker "stft", 150*da0073e9SAndroid Build Coastguard Worker "tensordot", 151*da0073e9SAndroid Build Coastguard Worker "split", 152*da0073e9SAndroid Build Coastguard Worker "unique_consecutive", 153*da0073e9SAndroid Build Coastguard Worker "atleast_1d", 154*da0073e9SAndroid Build Coastguard Worker "atleast_2d", 155*da0073e9SAndroid Build Coastguard Worker "atleast_3d", 156*da0073e9SAndroid Build Coastguard Worker # These are handled specially by python_arg_parser.cpp 157*da0073e9SAndroid Build Coastguard Worker "add", 158*da0073e9SAndroid Build Coastguard Worker "add_", 159*da0073e9SAndroid Build Coastguard Worker "add_out", 160*da0073e9SAndroid Build Coastguard Worker "sub", 161*da0073e9SAndroid Build Coastguard Worker "sub_", 162*da0073e9SAndroid Build Coastguard Worker "sub_out", 163*da0073e9SAndroid Build Coastguard Worker "mul", 164*da0073e9SAndroid Build Coastguard Worker "mul_", 165*da0073e9SAndroid Build Coastguard Worker "mul_out", 166*da0073e9SAndroid Build Coastguard Worker "div", 167*da0073e9SAndroid Build Coastguard Worker "div_", 168*da0073e9SAndroid Build Coastguard Worker "div_out", 169*da0073e9SAndroid Build Coastguard Worker "true_divide", 170*da0073e9SAndroid Build Coastguard Worker "true_divide_", 171*da0073e9SAndroid Build Coastguard Worker "true_divide_out", 172*da0073e9SAndroid Build Coastguard Worker "floor_divide", 173*da0073e9SAndroid Build Coastguard Worker 
"floor_divide_", 174*da0073e9SAndroid Build Coastguard Worker "floor_divide_out", 175*da0073e9SAndroid Build Coastguard Worker "to", 176*da0073e9SAndroid Build Coastguard Worker "_to_copy", 177*da0073e9SAndroid Build Coastguard Worker "copy_", 178*da0073e9SAndroid Build Coastguard Worker] 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Workerbinary_ops = ( 181*da0073e9SAndroid Build Coastguard Worker "add", 182*da0073e9SAndroid Build Coastguard Worker "sub", 183*da0073e9SAndroid Build Coastguard Worker "mul", 184*da0073e9SAndroid Build Coastguard Worker "div", 185*da0073e9SAndroid Build Coastguard Worker "pow", 186*da0073e9SAndroid Build Coastguard Worker "lshift", 187*da0073e9SAndroid Build Coastguard Worker "rshift", 188*da0073e9SAndroid Build Coastguard Worker "mod", 189*da0073e9SAndroid Build Coastguard Worker "truediv", 190*da0073e9SAndroid Build Coastguard Worker "matmul", 191*da0073e9SAndroid Build Coastguard Worker "floordiv", 192*da0073e9SAndroid Build Coastguard Worker "radd", 193*da0073e9SAndroid Build Coastguard Worker "rsub", 194*da0073e9SAndroid Build Coastguard Worker "rmul", 195*da0073e9SAndroid Build Coastguard Worker "rtruediv", 196*da0073e9SAndroid Build Coastguard Worker "rfloordiv", 197*da0073e9SAndroid Build Coastguard Worker "rpow", # reverse arithmetic 198*da0073e9SAndroid Build Coastguard Worker "and", 199*da0073e9SAndroid Build Coastguard Worker "or", 200*da0073e9SAndroid Build Coastguard Worker "xor", 201*da0073e9SAndroid Build Coastguard Worker "rand", 202*da0073e9SAndroid Build Coastguard Worker "ror", 203*da0073e9SAndroid Build Coastguard Worker "rxor", # logic 204*da0073e9SAndroid Build Coastguard Worker "iadd", 205*da0073e9SAndroid Build Coastguard Worker "iand", 206*da0073e9SAndroid Build Coastguard Worker "idiv", 207*da0073e9SAndroid Build Coastguard Worker "ilshift", 208*da0073e9SAndroid Build Coastguard Worker "imul", 209*da0073e9SAndroid Build Coastguard Worker "ior", 210*da0073e9SAndroid Build 
Coastguard Worker "irshift", 211*da0073e9SAndroid Build Coastguard Worker "isub", 212*da0073e9SAndroid Build Coastguard Worker "ixor", 213*da0073e9SAndroid Build Coastguard Worker "ifloordiv", 214*da0073e9SAndroid Build Coastguard Worker "imod", # inplace ops 215*da0073e9SAndroid Build Coastguard Worker) 216*da0073e9SAndroid Build Coastguard Workersymmetric_comparison_ops = ("eq", "ne") 217*da0073e9SAndroid Build Coastguard Workerasymmetric_comparison_ops = ("ge", "gt", "lt", "le") 218*da0073e9SAndroid Build Coastguard Workercomparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Workerunary_ops = ("neg", "abs", "invert") 221*da0073e9SAndroid Build Coastguard Workerto_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero") 222*da0073e9SAndroid Build Coastguard Workerall_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Workerdef sig_for_ops(opname: str) -> list[str]: 226*da0073e9SAndroid Build Coastguard Worker """sig_for_ops(opname : str) -> List[str] 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker Returns signatures for operator special functions (__add__ etc.)""" 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker # we have to do this by hand, because they are hand-bound in Python 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}" 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker name = opname[2:-2] 235*da0073e9SAndroid Build Coastguard Worker if name in binary_ops: 236*da0073e9SAndroid Build Coastguard Worker return [f"def {opname}(self, other: Any) -> Tensor: ..."] 
237*da0073e9SAndroid Build Coastguard Worker elif name in comparison_ops: 238*da0073e9SAndroid Build Coastguard Worker sig = f"def {opname}(self, other: Any) -> Tensor: ..." 239*da0073e9SAndroid Build Coastguard Worker if name in symmetric_comparison_ops: 240*da0073e9SAndroid Build Coastguard Worker # unsafe override https://github.com/python/mypy/issues/5704 241*da0073e9SAndroid Build Coastguard Worker sig += " # type: ignore[override]" 242*da0073e9SAndroid Build Coastguard Worker return [sig] 243*da0073e9SAndroid Build Coastguard Worker elif name in unary_ops: 244*da0073e9SAndroid Build Coastguard Worker return [f"def {opname}(self) -> Tensor: ..."] 245*da0073e9SAndroid Build Coastguard Worker elif name in to_py_type_ops: 246*da0073e9SAndroid Build Coastguard Worker if name in {"bool", "float", "complex"}: 247*da0073e9SAndroid Build Coastguard Worker tname = name 248*da0073e9SAndroid Build Coastguard Worker elif name == "nonzero": 249*da0073e9SAndroid Build Coastguard Worker tname = "bool" 250*da0073e9SAndroid Build Coastguard Worker else: 251*da0073e9SAndroid Build Coastguard Worker tname = "int" 252*da0073e9SAndroid Build Coastguard Worker if tname in {"float", "int", "bool", "complex"}: 253*da0073e9SAndroid Build Coastguard Worker tname = "builtins." 
+ tname 254*da0073e9SAndroid Build Coastguard Worker return [f"def {opname}(self) -> {tname}: ..."] 255*da0073e9SAndroid Build Coastguard Worker else: 256*da0073e9SAndroid Build Coastguard Worker raise Exception("unknown op", opname) # noqa: TRY002 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Workerdef generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]: 260*da0073e9SAndroid Build Coastguard Worker type_hints: list[str] = [] 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker # Some deprecated ops that are on the blocklist are still included in pyi 263*da0073e9SAndroid Build Coastguard Worker if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: 264*da0073e9SAndroid Build Coastguard Worker return type_hints 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker # deprecated signatures have separate entries for their functional and out variants 267*da0073e9SAndroid Build Coastguard Worker # (as opposed to the native ops, which fuse the two into a single signature). 268*da0073e9SAndroid Build Coastguard Worker # generate the functional variant here, if an out variant exists. 269*da0073e9SAndroid Build Coastguard Worker if sig_group.signature.deprecated and sig_group.outplace is not None: 270*da0073e9SAndroid Build Coastguard Worker type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True) 271*da0073e9SAndroid Build Coastguard Worker type_hints.append(type_hint) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument 274*da0073e9SAndroid Build Coastguard Worker # Generates the out variant if one exists. 
Otherwise, generate the functional variant 275*da0073e9SAndroid Build Coastguard Worker type_hint = sig_group.signature.signature_str_pyi( 276*da0073e9SAndroid Build Coastguard Worker skip_outputs=sig_group.outplace is None 277*da0073e9SAndroid Build Coastguard Worker ) 278*da0073e9SAndroid Build Coastguard Worker type_hints.append(type_hint) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker # Some operators also additionally have a vararg variant of their signature 281*da0073e9SAndroid Build Coastguard Worker type_hint_vararg = sig_group.signature.signature_str_pyi_vararg( 282*da0073e9SAndroid Build Coastguard Worker skip_outputs=sig_group.outplace is None 283*da0073e9SAndroid Build Coastguard Worker ) 284*da0073e9SAndroid Build Coastguard Worker if type_hint_vararg: 285*da0073e9SAndroid Build Coastguard Worker type_hints.append(type_hint_vararg) 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker return type_hints 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Workerdef get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]: 291*da0073e9SAndroid Build Coastguard Worker flag_pos = arg_list.index("{return_indices}") 292*da0073e9SAndroid Build Coastguard Worker # If return_indices is positional arg, everything before should have no default 293*da0073e9SAndroid Build Coastguard Worker arg_list_positional = ( 294*da0073e9SAndroid Build Coastguard Worker [ 295*da0073e9SAndroid Build Coastguard Worker ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", ")) 296*da0073e9SAndroid Build Coastguard Worker for arg in arg_list[: flag_pos + 1] 297*da0073e9SAndroid Build Coastguard Worker ] 298*da0073e9SAndroid Build Coastguard Worker + ["/"] 299*da0073e9SAndroid Build Coastguard Worker + arg_list[flag_pos + 1 :] 300*da0073e9SAndroid Build Coastguard Worker ) 301*da0073e9SAndroid 
Build Coastguard Worker # Otherwise force return_indices to be kwarg 302*da0073e9SAndroid Build Coastguard Worker arg_list_keyword = arg_list.copy() 303*da0073e9SAndroid Build Coastguard Worker arg_list_keyword.insert(flag_pos, "*") 304*da0073e9SAndroid Build Coastguard Worker tmpl = "def {name}({args}) -> {{return_type}}: ..." 305*da0073e9SAndroid Build Coastguard Worker return { 306*da0073e9SAndroid Build Coastguard Worker name: [ 307*da0073e9SAndroid Build Coastguard Worker tmpl.format(name=name, args=", ".join(arg_list)).format( 308*da0073e9SAndroid Build Coastguard Worker return_indices="return_indices: Literal[False] = False", 309*da0073e9SAndroid Build Coastguard Worker return_type="Tensor", 310*da0073e9SAndroid Build Coastguard Worker ), 311*da0073e9SAndroid Build Coastguard Worker tmpl.format(name=name, args=", ".join(arg_list_positional)).format( 312*da0073e9SAndroid Build Coastguard Worker return_indices="return_indices: Literal[True]", 313*da0073e9SAndroid Build Coastguard Worker return_type="Tuple[Tensor, Tensor]", 314*da0073e9SAndroid Build Coastguard Worker ), 315*da0073e9SAndroid Build Coastguard Worker tmpl.format(name=name, args=", ".join(arg_list_keyword)).format( 316*da0073e9SAndroid Build Coastguard Worker return_indices="return_indices: Literal[True]", 317*da0073e9SAndroid Build Coastguard Worker return_type="Tuple[Tensor, Tensor]", 318*da0073e9SAndroid Build Coastguard Worker ), 319*da0073e9SAndroid Build Coastguard Worker ] 320*da0073e9SAndroid Build Coastguard Worker } 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Workerdef gen_nn_functional(fm: FileManager) -> None: 324*da0073e9SAndroid Build Coastguard Worker INPUT = "input: Tensor" 325*da0073e9SAndroid Build Coastguard Worker KERNEL_SIZE = "kernel_size: Union[_int, _size]" 326*da0073e9SAndroid Build Coastguard Worker STRIDE_PADDING = ", ".join( 327*da0073e9SAndroid Build Coastguard Worker [ 
328*da0073e9SAndroid Build Coastguard Worker "stride: Optional[Union[_int, _size]] = None", 329*da0073e9SAndroid Build Coastguard Worker "padding: Union[_int, _size] = 0", 330*da0073e9SAndroid Build Coastguard Worker ] 331*da0073e9SAndroid Build Coastguard Worker ) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker # TODO the list for `torch._C._nn` is nonexhaustive 334*da0073e9SAndroid Build Coastguard Worker unsorted_c_nn_function_hints: dict[str, list[str]] = {} 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker for d in (2, 3): 337*da0073e9SAndroid Build Coastguard Worker unsorted_c_nn_function_hints.update( 338*da0073e9SAndroid Build Coastguard Worker { 339*da0073e9SAndroid Build Coastguard Worker f"avg_pool{d}d": [ 340*da0073e9SAndroid Build Coastguard Worker f"def avg_pool{d}d({{}}) -> Tensor: ...".format( 341*da0073e9SAndroid Build Coastguard Worker ", ".join( 342*da0073e9SAndroid Build Coastguard Worker [ 343*da0073e9SAndroid Build Coastguard Worker f"{INPUT}", 344*da0073e9SAndroid Build Coastguard Worker f"{KERNEL_SIZE}", 345*da0073e9SAndroid Build Coastguard Worker f"{STRIDE_PADDING}", 346*da0073e9SAndroid Build Coastguard Worker "ceil_mode: bool = False", 347*da0073e9SAndroid Build Coastguard Worker "count_include_pad: bool = True", 348*da0073e9SAndroid Build Coastguard Worker "divisor_override: Optional[int] = None", 349*da0073e9SAndroid Build Coastguard Worker ] 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker ) 352*da0073e9SAndroid Build Coastguard Worker ], 353*da0073e9SAndroid Build Coastguard Worker f"fractional_max_pool{d}d": [ 354*da0073e9SAndroid Build Coastguard Worker f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format( 355*da0073e9SAndroid Build Coastguard Worker ", ".join( 356*da0073e9SAndroid Build Coastguard Worker [ 357*da0073e9SAndroid Build Coastguard Worker f"{INPUT}", 358*da0073e9SAndroid Build Coastguard 
Worker f"{KERNEL_SIZE}", 359*da0073e9SAndroid Build Coastguard Worker "output_size: Union[_int, _size]", 360*da0073e9SAndroid Build Coastguard Worker "_random_samples: Tensor", 361*da0073e9SAndroid Build Coastguard Worker ] 362*da0073e9SAndroid Build Coastguard Worker ), 363*da0073e9SAndroid Build Coastguard Worker "Tuple[Tensor, Tensor]", 364*da0073e9SAndroid Build Coastguard Worker ) 365*da0073e9SAndroid Build Coastguard Worker ], 366*da0073e9SAndroid Build Coastguard Worker f"adaptive_max_pool{d}d": [ 367*da0073e9SAndroid Build Coastguard Worker f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format( 368*da0073e9SAndroid Build Coastguard Worker ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]), 369*da0073e9SAndroid Build Coastguard Worker "Tuple[Tensor, Tensor]", 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker ], 372*da0073e9SAndroid Build Coastguard Worker } 373*da0073e9SAndroid Build Coastguard Worker ) 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker unsorted_c_nn_function_hints.update( 376*da0073e9SAndroid Build Coastguard Worker { 377*da0073e9SAndroid Build Coastguard Worker "hardtanh": [ 378*da0073e9SAndroid Build Coastguard Worker "def hardtanh({}) -> Tensor: ...".format( 379*da0073e9SAndroid Build Coastguard Worker ", ".join( 380*da0073e9SAndroid Build Coastguard Worker [ 381*da0073e9SAndroid Build Coastguard Worker "input: Tensor", 382*da0073e9SAndroid Build Coastguard Worker "min_val: float = ...", 383*da0073e9SAndroid Build Coastguard Worker "max_val: float = ...", 384*da0073e9SAndroid Build Coastguard Worker "*", 385*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 386*da0073e9SAndroid Build Coastguard Worker ] 387*da0073e9SAndroid Build Coastguard Worker ) 388*da0073e9SAndroid Build Coastguard Worker ) 389*da0073e9SAndroid Build Coastguard Worker ], 390*da0073e9SAndroid Build Coastguard Worker "hardtanh_": [ 391*da0073e9SAndroid 
Build Coastguard Worker "def hardtanh_({}) -> Tensor: ...".format( 392*da0073e9SAndroid Build Coastguard Worker ", ".join( 393*da0073e9SAndroid Build Coastguard Worker [ 394*da0073e9SAndroid Build Coastguard Worker "input: Tensor", 395*da0073e9SAndroid Build Coastguard Worker "min_val: float = ...", 396*da0073e9SAndroid Build Coastguard Worker "max_val: float = ...", 397*da0073e9SAndroid Build Coastguard Worker ] 398*da0073e9SAndroid Build Coastguard Worker ) 399*da0073e9SAndroid Build Coastguard Worker ) 400*da0073e9SAndroid Build Coastguard Worker ], 401*da0073e9SAndroid Build Coastguard Worker "elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."], 402*da0073e9SAndroid Build Coastguard Worker "leaky_relu": [ 403*da0073e9SAndroid Build Coastguard Worker "def leaky_relu({}) -> Tensor: ...".format( 404*da0073e9SAndroid Build Coastguard Worker ", ".join( 405*da0073e9SAndroid Build Coastguard Worker [ 406*da0073e9SAndroid Build Coastguard Worker "input: Tensor", 407*da0073e9SAndroid Build Coastguard Worker "negative_slope: float = ...", 408*da0073e9SAndroid Build Coastguard Worker "*", 409*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 410*da0073e9SAndroid Build Coastguard Worker ] 411*da0073e9SAndroid Build Coastguard Worker ) 412*da0073e9SAndroid Build Coastguard Worker ) 413*da0073e9SAndroid Build Coastguard Worker ], 414*da0073e9SAndroid Build Coastguard Worker "leaky_relu_": [ 415*da0073e9SAndroid Build Coastguard Worker f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..." 416*da0073e9SAndroid Build Coastguard Worker ], 417*da0073e9SAndroid Build Coastguard Worker "log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."], 418*da0073e9SAndroid Build Coastguard Worker "gelu": ["def gelu(input: Tensor, approximate: str = ...) 
-> Tensor: ..."], 419*da0073e9SAndroid Build Coastguard Worker "softplus": [ 420*da0073e9SAndroid Build Coastguard Worker "def softplus({}) -> Tensor: ...".format( 421*da0073e9SAndroid Build Coastguard Worker ", ".join( 422*da0073e9SAndroid Build Coastguard Worker ["input: Tensor", "beta: float = ...", "threshold: float = ..."] 423*da0073e9SAndroid Build Coastguard Worker ) 424*da0073e9SAndroid Build Coastguard Worker ) 425*da0073e9SAndroid Build Coastguard Worker ], 426*da0073e9SAndroid Build Coastguard Worker "softshrink": [ 427*da0073e9SAndroid Build Coastguard Worker "def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..." 428*da0073e9SAndroid Build Coastguard Worker ], 429*da0073e9SAndroid Build Coastguard Worker "hardsigmoid": [ 430*da0073e9SAndroid Build Coastguard Worker f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..." 431*da0073e9SAndroid Build Coastguard Worker ], 432*da0073e9SAndroid Build Coastguard Worker "linear": [ 433*da0073e9SAndroid Build Coastguard Worker "def linear({}) -> Tensor: ...".format( 434*da0073e9SAndroid Build Coastguard Worker ", ".join( 435*da0073e9SAndroid Build Coastguard Worker [ 436*da0073e9SAndroid Build Coastguard Worker "input: Tensor", 437*da0073e9SAndroid Build Coastguard Worker "weight: Tensor", 438*da0073e9SAndroid Build Coastguard Worker "bias: Optional[Tensor] = None", 439*da0073e9SAndroid Build Coastguard Worker ] 440*da0073e9SAndroid Build Coastguard Worker ) 441*da0073e9SAndroid Build Coastguard Worker ) 442*da0073e9SAndroid Build Coastguard Worker ], 443*da0073e9SAndroid Build Coastguard Worker "pad": [ 444*da0073e9SAndroid Build Coastguard Worker "def pad({}) -> Tensor: ...".format( 445*da0073e9SAndroid Build Coastguard Worker ", ".join( 446*da0073e9SAndroid Build Coastguard Worker [ 447*da0073e9SAndroid Build Coastguard Worker "input: Tensor", 448*da0073e9SAndroid Build Coastguard Worker "pad: Sequence[int]", 449*da0073e9SAndroid Build Coastguard 
Worker "mode: str = ...", 450*da0073e9SAndroid Build Coastguard Worker "value: Optional[float] = None", 451*da0073e9SAndroid Build Coastguard Worker ] 452*da0073e9SAndroid Build Coastguard Worker ) 453*da0073e9SAndroid Build Coastguard Worker ) 454*da0073e9SAndroid Build Coastguard Worker ], 455*da0073e9SAndroid Build Coastguard Worker "one_hot": [ 456*da0073e9SAndroid Build Coastguard Worker "def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..." 457*da0073e9SAndroid Build Coastguard Worker ], 458*da0073e9SAndroid Build Coastguard Worker "scaled_dot_product_attention": [ 459*da0073e9SAndroid Build Coastguard Worker "def scaled_dot_product_attention({}) -> Tensor: ...".format( 460*da0073e9SAndroid Build Coastguard Worker ", ".join( 461*da0073e9SAndroid Build Coastguard Worker [ 462*da0073e9SAndroid Build Coastguard Worker "query: Tensor", 463*da0073e9SAndroid Build Coastguard Worker "key: Tensor", 464*da0073e9SAndroid Build Coastguard Worker "value: Tensor", 465*da0073e9SAndroid Build Coastguard Worker "attn_mask: Optional[Tensor] = None", 466*da0073e9SAndroid Build Coastguard Worker "dropout_p: float = 0.0", 467*da0073e9SAndroid Build Coastguard Worker "is_causal: bool = False", 468*da0073e9SAndroid Build Coastguard Worker "scale: Optional[float] = None", 469*da0073e9SAndroid Build Coastguard Worker "enable_gqa: bool = False", 470*da0073e9SAndroid Build Coastguard Worker ] 471*da0073e9SAndroid Build Coastguard Worker ) 472*da0073e9SAndroid Build Coastguard Worker ) 473*da0073e9SAndroid Build Coastguard Worker ], 474*da0073e9SAndroid Build Coastguard Worker } 475*da0073e9SAndroid Build Coastguard Worker ) 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker c_nn_function_hints: list[str] = [] 478*da0073e9SAndroid Build Coastguard Worker for _, hints in sorted(unsorted_c_nn_function_hints.items()): 479*da0073e9SAndroid Build Coastguard Worker if len(hints) > 1: 480*da0073e9SAndroid Build Coastguard Worker hints = 
["@overload\n" + h for h in hints] 481*da0073e9SAndroid Build Coastguard Worker c_nn_function_hints += hints 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered 484*da0073e9SAndroid Build Coastguard Worker # through an `_add_docstr` call 485*da0073e9SAndroid Build Coastguard Worker torch_imports = [ 486*da0073e9SAndroid Build Coastguard Worker "conv1d", 487*da0073e9SAndroid Build Coastguard Worker "conv2d", 488*da0073e9SAndroid Build Coastguard Worker "conv3d", 489*da0073e9SAndroid Build Coastguard Worker "conv_transpose1d", 490*da0073e9SAndroid Build Coastguard Worker "conv_transpose2d", 491*da0073e9SAndroid Build Coastguard Worker "conv_transpose3d", 492*da0073e9SAndroid Build Coastguard Worker "conv_tbc", 493*da0073e9SAndroid Build Coastguard Worker "avg_pool1d", 494*da0073e9SAndroid Build Coastguard Worker "adaptive_avg_pool1d", 495*da0073e9SAndroid Build Coastguard Worker "relu_", 496*da0073e9SAndroid Build Coastguard Worker "selu_", 497*da0073e9SAndroid Build Coastguard Worker "celu_", 498*da0073e9SAndroid Build Coastguard Worker "prelu", 499*da0073e9SAndroid Build Coastguard Worker "rrelu_", 500*da0073e9SAndroid Build Coastguard Worker "hardshrink", 501*da0073e9SAndroid Build Coastguard Worker "bilinear", 502*da0073e9SAndroid Build Coastguard Worker "pixel_shuffle", 503*da0073e9SAndroid Build Coastguard Worker "pixel_unshuffle", 504*da0073e9SAndroid Build Coastguard Worker "channel_shuffle", 505*da0073e9SAndroid Build Coastguard Worker "native_channel_shuffle", 506*da0073e9SAndroid Build Coastguard Worker "pairwise_distance", 507*da0073e9SAndroid Build Coastguard Worker "pdist", 508*da0073e9SAndroid Build Coastguard Worker "cosine_similarity", 509*da0073e9SAndroid Build Coastguard Worker ] 510*da0073e9SAndroid Build Coastguard Worker imported_hints = [f"from torch import {_} as {_}" for _ in torch_imports] 511*da0073e9SAndroid Build 
Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker # Functions imported into `torch.nn.functional` from `torch._C._nn` 513*da0073e9SAndroid Build Coastguard Worker c_nn_imports = [ 514*da0073e9SAndroid Build Coastguard Worker "avg_pool2d", 515*da0073e9SAndroid Build Coastguard Worker "avg_pool3d", 516*da0073e9SAndroid Build Coastguard Worker "hardtanh_", 517*da0073e9SAndroid Build Coastguard Worker "elu_", 518*da0073e9SAndroid Build Coastguard Worker "leaky_relu_", 519*da0073e9SAndroid Build Coastguard Worker "gelu", 520*da0073e9SAndroid Build Coastguard Worker "softplus", 521*da0073e9SAndroid Build Coastguard Worker "softshrink", 522*da0073e9SAndroid Build Coastguard Worker "linear", 523*da0073e9SAndroid Build Coastguard Worker "pad", 524*da0073e9SAndroid Build Coastguard Worker "one_hot", 525*da0073e9SAndroid Build Coastguard Worker "scaled_dot_product_attention", 526*da0073e9SAndroid Build Coastguard Worker ] 527*da0073e9SAndroid Build Coastguard Worker imported_hints += [f"from torch._C._nn import {_} as {_}" for _ in c_nn_imports] 528*da0073e9SAndroid Build Coastguard Worker # This is from `torch._C._nn` but renamed 529*da0073e9SAndroid Build Coastguard Worker imported_hints.append( 530*da0073e9SAndroid Build Coastguard Worker "from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid" 531*da0073e9SAndroid Build Coastguard Worker ) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional` 534*da0073e9SAndroid Build Coastguard Worker unsorted_dispatched_hints: dict[str, list[str]] = {} 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker for d in (1, 2, 3): 537*da0073e9SAndroid Build Coastguard Worker unsorted_dispatched_hints.update( 538*da0073e9SAndroid Build Coastguard Worker **get_max_pool_dispatch( 539*da0073e9SAndroid Build Coastguard Worker f"max_pool{d}d", 540*da0073e9SAndroid Build 
Coastguard Worker [ 541*da0073e9SAndroid Build Coastguard Worker f"{INPUT}", 542*da0073e9SAndroid Build Coastguard Worker f"{KERNEL_SIZE}", 543*da0073e9SAndroid Build Coastguard Worker f"{STRIDE_PADDING}", 544*da0073e9SAndroid Build Coastguard Worker "dilation: Union[_int, _size] = 1", 545*da0073e9SAndroid Build Coastguard Worker "ceil_mode: bool = False", 546*da0073e9SAndroid Build Coastguard Worker "{return_indices}", 547*da0073e9SAndroid Build Coastguard Worker ], 548*da0073e9SAndroid Build Coastguard Worker ), 549*da0073e9SAndroid Build Coastguard Worker **get_max_pool_dispatch( 550*da0073e9SAndroid Build Coastguard Worker f"fractional_max_pool{d}d", 551*da0073e9SAndroid Build Coastguard Worker [ 552*da0073e9SAndroid Build Coastguard Worker f"{INPUT}", 553*da0073e9SAndroid Build Coastguard Worker f"{KERNEL_SIZE}", 554*da0073e9SAndroid Build Coastguard Worker "output_size: Optional[Union[_int, _size]] = None", 555*da0073e9SAndroid Build Coastguard Worker "output_ratio: Optional[_ratio_any_t] = None", 556*da0073e9SAndroid Build Coastguard Worker "{return_indices}", 557*da0073e9SAndroid Build Coastguard Worker "_random_samples: Optional[Tensor] = None", 558*da0073e9SAndroid Build Coastguard Worker ], 559*da0073e9SAndroid Build Coastguard Worker ), 560*da0073e9SAndroid Build Coastguard Worker **get_max_pool_dispatch( 561*da0073e9SAndroid Build Coastguard Worker f"adaptive_max_pool{d}d", 562*da0073e9SAndroid Build Coastguard Worker [f"{INPUT}", "output_size: Union[_int, _size]", "{return_indices}"], 563*da0073e9SAndroid Build Coastguard Worker ), 564*da0073e9SAndroid Build Coastguard Worker ) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker # There's no fractional_max_pool1d 567*da0073e9SAndroid Build Coastguard Worker del unsorted_dispatched_hints["fractional_max_pool1d"] 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker dispatched_hints: list[str] = [] 570*da0073e9SAndroid Build 
Coastguard Worker for _, hints in sorted(unsorted_dispatched_hints.items()): 571*da0073e9SAndroid Build Coastguard Worker if len(hints) > 1: 572*da0073e9SAndroid Build Coastguard Worker hints = ["@overload\n" + h for h in hints] 573*da0073e9SAndroid Build Coastguard Worker dispatched_hints += hints 574*da0073e9SAndroid Build Coastguard Worker 575*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 576*da0073e9SAndroid Build Coastguard Worker "torch/nn/functional.pyi", 577*da0073e9SAndroid Build Coastguard Worker "torch/nn/functional.pyi.in", 578*da0073e9SAndroid Build Coastguard Worker lambda: { 579*da0073e9SAndroid Build Coastguard Worker "imported_hints": imported_hints, 580*da0073e9SAndroid Build Coastguard Worker "dispatched_hints": dispatched_hints, 581*da0073e9SAndroid Build Coastguard Worker }, 582*da0073e9SAndroid Build Coastguard Worker ) 583*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 584*da0073e9SAndroid Build Coastguard Worker "torch/_C/_nn.pyi", 585*da0073e9SAndroid Build Coastguard Worker "torch/_C/_nn.pyi.in", 586*da0073e9SAndroid Build Coastguard Worker lambda: { 587*da0073e9SAndroid Build Coastguard Worker "c_nn_function_hints": c_nn_function_hints, 588*da0073e9SAndroid Build Coastguard Worker }, 589*da0073e9SAndroid Build Coastguard Worker ) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker""" 593*da0073e9SAndroid Build Coastguard WorkerWe gather the docstrings for torch with the following steps: 594*da0073e9SAndroid Build Coastguard Worker1. Mock torch and torch._C, which are the only dependencies of the docs files 595*da0073e9SAndroid Build Coastguard Worker2. Mock the _add_docstr function to save the docstrings 596*da0073e9SAndroid Build Coastguard Worker3. 
Import the docs files to trigger mocked _add_docstr and collect docstrings 597*da0073e9SAndroid Build Coastguard Worker""" 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Workerdef gather_docstrs() -> dict[str, str]: 601*da0073e9SAndroid Build Coastguard Worker docstrs = {} 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker def mock_add_docstr(func: Mock, docstr: str) -> None: 604*da0073e9SAndroid Build Coastguard Worker docstrs[func._extract_mock_name()] = docstr.strip() 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker # sys.modules and sys.path are restored after the context manager exits 607*da0073e9SAndroid Build Coastguard Worker with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]): 608*da0073e9SAndroid Build Coastguard Worker # mock the torch module and torch._C._add_docstr 609*da0073e9SAndroid Build Coastguard Worker sys.modules["torch"] = Mock(name="torch") 610*da0073e9SAndroid Build Coastguard Worker sys.modules["torch._C"] = Mock(_add_docstr=mock_add_docstr) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker try: 613*da0073e9SAndroid Build Coastguard Worker # manually import torch._torch_docs and torch._tensor_docs to trigger 614*da0073e9SAndroid Build Coastguard Worker # the mocked _add_docstr and collect docstrings 615*da0073e9SAndroid Build Coastguard Worker sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs") 616*da0073e9SAndroid Build Coastguard Worker sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs") 617*da0073e9SAndroid Build Coastguard Worker except ModuleNotFoundError: 618*da0073e9SAndroid Build Coastguard Worker # Gracefully fail if these modules are not importable 619*da0073e9SAndroid Build Coastguard Worker warn( 620*da0073e9SAndroid Build Coastguard Worker "Failed to import 
_torch_docs/_tensor_docs, skipping docstring in pyi files." 621*da0073e9SAndroid Build Coastguard Worker ) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker return docstrs 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Workerdef add_docstr_to_hint(docstr: str, hint: str) -> str: 627*da0073e9SAndroid Build Coastguard Worker if "..." in hint: # function or method 628*da0073e9SAndroid Build Coastguard Worker assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'" 629*da0073e9SAndroid Build Coastguard Worker hint = hint[:-3] # remove "..." 630*da0073e9SAndroid Build Coastguard Worker return "\n ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."]) 631*da0073e9SAndroid Build Coastguard Worker else: # attribute or property 632*da0073e9SAndroid Build Coastguard Worker return f'{hint}\nr"""{docstr}"""\n' 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Workerdef gen_pyi( 636*da0073e9SAndroid Build Coastguard Worker native_yaml_path: str, 637*da0073e9SAndroid Build Coastguard Worker tags_yaml_path: str, 638*da0073e9SAndroid Build Coastguard Worker deprecated_yaml_path: str, 639*da0073e9SAndroid Build Coastguard Worker fm: FileManager, 640*da0073e9SAndroid Build Coastguard Worker) -> None: 641*da0073e9SAndroid Build Coastguard Worker """gen_pyi() 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker This function generates a pyi file for torch. 
644*da0073e9SAndroid Build Coastguard Worker """ 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker # Some of this logic overlaps with generate_python_signature in 647*da0073e9SAndroid Build Coastguard Worker # tools/autograd/gen_python_functions.py; however, this 648*da0073e9SAndroid Build Coastguard Worker # function is all about generating mypy type signatures, whereas 649*da0073e9SAndroid Build Coastguard Worker # the other function generates are custom format for argument 650*da0073e9SAndroid Build Coastguard Worker # checking. If you are update this, consider if your change 651*da0073e9SAndroid Build Coastguard Worker # also needs to update the other file. 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker # Dictionary for NamedTuple definitions 654*da0073e9SAndroid Build Coastguard Worker structseqs: dict[str, str] = {} 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker # Generate type signatures for top-level functions 657*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list) 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker for n, n1, n2 in [ 662*da0073e9SAndroid Build Coastguard Worker ("csr", "crow", "col"), 663*da0073e9SAndroid Build Coastguard Worker ("csc", "ccol", "row"), 664*da0073e9SAndroid Build Coastguard Worker ("bsr", "crow", "col"), 665*da0073e9SAndroid Build Coastguard Worker ("bsc", "ccol", "row"), 666*da0073e9SAndroid Build Coastguard Worker ]: 667*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints.update( 668*da0073e9SAndroid Build Coastguard Worker { 669*da0073e9SAndroid Build Coastguard Worker f"sparse_{n}_tensor": [ 670*da0073e9SAndroid Build Coastguard Worker f"def 
sparse_{n}_tensor({{}}) -> Tensor: ...".format( 671*da0073e9SAndroid Build Coastguard Worker ", ".join( 672*da0073e9SAndroid Build Coastguard Worker [ 673*da0073e9SAndroid Build Coastguard Worker f"{n1}_indices: Union[Tensor, List]", 674*da0073e9SAndroid Build Coastguard Worker f"{n2}_indices: Union[Tensor, List]", 675*da0073e9SAndroid Build Coastguard Worker "values: Union[Tensor, List]", 676*da0073e9SAndroid Build Coastguard Worker "size: Optional[_size] = None", 677*da0073e9SAndroid Build Coastguard Worker "*", 678*da0073e9SAndroid Build Coastguard Worker "dtype: Optional[_dtype] = None", 679*da0073e9SAndroid Build Coastguard Worker "device: Optional[DeviceLikeType] = None", 680*da0073e9SAndroid Build Coastguard Worker "requires_grad: _bool = False", 681*da0073e9SAndroid Build Coastguard Worker "check_invariants: Optional[_bool] = None", 682*da0073e9SAndroid Build Coastguard Worker ] 683*da0073e9SAndroid Build Coastguard Worker ), 684*da0073e9SAndroid Build Coastguard Worker ) 685*da0073e9SAndroid Build Coastguard Worker ], 686*da0073e9SAndroid Build Coastguard Worker } 687*da0073e9SAndroid Build Coastguard Worker ) 688*da0073e9SAndroid Build Coastguard Worker 689*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints.update( 690*da0073e9SAndroid Build Coastguard Worker { 691*da0073e9SAndroid Build Coastguard Worker "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."], 692*da0073e9SAndroid Build Coastguard Worker "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."], 693*da0073e9SAndroid Build Coastguard Worker "asarray": [ 694*da0073e9SAndroid Build Coastguard Worker "def asarray({}) -> Tensor: ...".format( 695*da0073e9SAndroid Build Coastguard Worker ", ".join( 696*da0073e9SAndroid Build Coastguard Worker [ 697*da0073e9SAndroid Build Coastguard Worker "obj: Any", 698*da0073e9SAndroid Build Coastguard Worker "*", 699*da0073e9SAndroid Build Coastguard Worker "dtype: Optional[_dtype] = None", 700*da0073e9SAndroid Build 
Coastguard Worker "device: Optional[DeviceLikeType] = None", 701*da0073e9SAndroid Build Coastguard Worker "copy: Optional[_bool] = None", 702*da0073e9SAndroid Build Coastguard Worker "requires_grad: _bool = False", 703*da0073e9SAndroid Build Coastguard Worker ] 704*da0073e9SAndroid Build Coastguard Worker ) 705*da0073e9SAndroid Build Coastguard Worker ) 706*da0073e9SAndroid Build Coastguard Worker ], 707*da0073e9SAndroid Build Coastguard Worker "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."], 708*da0073e9SAndroid Build Coastguard Worker "frombuffer": [ 709*da0073e9SAndroid Build Coastguard Worker "def frombuffer({}) -> Tensor: ...".format( 710*da0073e9SAndroid Build Coastguard Worker ", ".join( 711*da0073e9SAndroid Build Coastguard Worker [ 712*da0073e9SAndroid Build Coastguard Worker "buffer: Any", 713*da0073e9SAndroid Build Coastguard Worker "*", 714*da0073e9SAndroid Build Coastguard Worker "dtype: _dtype", 715*da0073e9SAndroid Build Coastguard Worker "count: int = -1", 716*da0073e9SAndroid Build Coastguard Worker "offset: int = 0", 717*da0073e9SAndroid Build Coastguard Worker "requires_grad: _bool = False", 718*da0073e9SAndroid Build Coastguard Worker ] 719*da0073e9SAndroid Build Coastguard Worker ) 720*da0073e9SAndroid Build Coastguard Worker ) 721*da0073e9SAndroid Build Coastguard Worker ], 722*da0073e9SAndroid Build Coastguard Worker "numel": ["def numel(self: Tensor) -> _int: ..."], 723*da0073e9SAndroid Build Coastguard Worker "as_tensor": [ 724*da0073e9SAndroid Build Coastguard Worker "def as_tensor({}) -> Tensor: ...".format( 725*da0073e9SAndroid Build Coastguard Worker ", ".join( 726*da0073e9SAndroid Build Coastguard Worker [ 727*da0073e9SAndroid Build Coastguard Worker "data: Any", 728*da0073e9SAndroid Build Coastguard Worker "dtype: Optional[_dtype] = None", 729*da0073e9SAndroid Build Coastguard Worker DEVICE_PARAM, 730*da0073e9SAndroid Build Coastguard Worker ] 731*da0073e9SAndroid Build Coastguard Worker ) 732*da0073e9SAndroid Build 
Coastguard Worker ) 733*da0073e9SAndroid Build Coastguard Worker ], 734*da0073e9SAndroid Build Coastguard Worker "get_num_threads": ["def get_num_threads() -> _int: ..."], 735*da0073e9SAndroid Build Coastguard Worker "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."], 736*da0073e9SAndroid Build Coastguard Worker "init_num_threads": ["def init_num_threads() -> None: ..."], 737*da0073e9SAndroid Build Coastguard Worker "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."], 738*da0073e9SAndroid Build Coastguard Worker "set_num_interop_threads": [ 739*da0073e9SAndroid Build Coastguard Worker "def set_num_interop_threads(num: _int) -> None: ..." 740*da0073e9SAndroid Build Coastguard Worker ], 741*da0073e9SAndroid Build Coastguard Worker # These functions are explicitly disabled by 742*da0073e9SAndroid Build Coastguard Worker # SKIP_PYTHON_BINDINGS because they are hand bound. 743*da0073e9SAndroid Build Coastguard Worker # Correspondingly, we must hand-write their signatures. 
744*da0073e9SAndroid Build Coastguard Worker "tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."], 745*da0073e9SAndroid Build Coastguard Worker "sparse_coo_tensor": [ 746*da0073e9SAndroid Build Coastguard Worker "def sparse_coo_tensor({}) -> Tensor: ...".format( 747*da0073e9SAndroid Build Coastguard Worker ", ".join( 748*da0073e9SAndroid Build Coastguard Worker [ 749*da0073e9SAndroid Build Coastguard Worker "indices: Tensor", 750*da0073e9SAndroid Build Coastguard Worker "values: Union[Tensor, List]", 751*da0073e9SAndroid Build Coastguard Worker "size: Optional[_size] = None", 752*da0073e9SAndroid Build Coastguard Worker "*", 753*da0073e9SAndroid Build Coastguard Worker "dtype: Optional[_dtype] = None", 754*da0073e9SAndroid Build Coastguard Worker "device: Optional[DeviceLikeType] = None", 755*da0073e9SAndroid Build Coastguard Worker "requires_grad: _bool = False", 756*da0073e9SAndroid Build Coastguard Worker "check_invariants: Optional[_bool] = None", 757*da0073e9SAndroid Build Coastguard Worker "is_coalesced: Optional[_bool] = None", 758*da0073e9SAndroid Build Coastguard Worker ] 759*da0073e9SAndroid Build Coastguard Worker ) 760*da0073e9SAndroid Build Coastguard Worker ) 761*da0073e9SAndroid Build Coastguard Worker ], 762*da0073e9SAndroid Build Coastguard Worker "sparse_compressed_tensor": [ 763*da0073e9SAndroid Build Coastguard Worker "def sparse_compressed_tensor({}) -> Tensor: ...".format( 764*da0073e9SAndroid Build Coastguard Worker ", ".join( 765*da0073e9SAndroid Build Coastguard Worker [ 766*da0073e9SAndroid Build Coastguard Worker "compressed_indices: Union[Tensor, List]", 767*da0073e9SAndroid Build Coastguard Worker "plain_indices: Union[Tensor, List]", 768*da0073e9SAndroid Build Coastguard Worker "values: Union[Tensor, List]", 769*da0073e9SAndroid Build Coastguard Worker "size: Optional[_size] = None", 770*da0073e9SAndroid Build Coastguard Worker "*", 771*da0073e9SAndroid Build Coastguard Worker "dtype: Optional[_dtype] = None", 
772*da0073e9SAndroid Build Coastguard Worker "layout: Optional[_layout] = None", 773*da0073e9SAndroid Build Coastguard Worker "device: Optional[DeviceLikeType] = None", 774*da0073e9SAndroid Build Coastguard Worker "requires_grad: _bool = False", 775*da0073e9SAndroid Build Coastguard Worker "check_invariants: Optional[_bool] = None", 776*da0073e9SAndroid Build Coastguard Worker ] 777*da0073e9SAndroid Build Coastguard Worker ) 778*da0073e9SAndroid Build Coastguard Worker ) 779*da0073e9SAndroid Build Coastguard Worker ], 780*da0073e9SAndroid Build Coastguard Worker "_sync": ["def _sync(t: Tensor) -> None: ..."], 781*da0073e9SAndroid Build Coastguard Worker "_is_functional_tensor": [ 782*da0073e9SAndroid Build Coastguard Worker "def _is_functional_tensor(t: Tensor) -> _bool: ..." 783*da0073e9SAndroid Build Coastguard Worker ], 784*da0073e9SAndroid Build Coastguard Worker "_is_functional_tensor_base": [ 785*da0073e9SAndroid Build Coastguard Worker "def _is_functional_tensor_base(t: Tensor) -> _bool: ..." 786*da0073e9SAndroid Build Coastguard Worker ], 787*da0073e9SAndroid Build Coastguard Worker "_from_functional_tensor": [ 788*da0073e9SAndroid Build Coastguard Worker "def _from_functional_tensor(t: Tensor) -> Tensor: ..." 789*da0073e9SAndroid Build Coastguard Worker ], 790*da0073e9SAndroid Build Coastguard Worker "_to_functional_tensor": [ 791*da0073e9SAndroid Build Coastguard Worker "def _to_functional_tensor(t: Tensor) -> Tensor: ..." 792*da0073e9SAndroid Build Coastguard Worker ], 793*da0073e9SAndroid Build Coastguard Worker "_functionalize_replace": [ 794*da0073e9SAndroid Build Coastguard Worker "def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..." 795*da0073e9SAndroid Build Coastguard Worker ], 796*da0073e9SAndroid Build Coastguard Worker "_functionalize_commit_update": [ 797*da0073e9SAndroid Build Coastguard Worker "def _functionalize_commit_update(t: Tensor) -> None: ..." 
798*da0073e9SAndroid Build Coastguard Worker ], 799*da0073e9SAndroid Build Coastguard Worker "_functionalize_unsafe_set": [ 800*da0073e9SAndroid Build Coastguard Worker "def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..." 801*da0073e9SAndroid Build Coastguard Worker ], 802*da0073e9SAndroid Build Coastguard Worker "_functionalize_mark_mutation_hidden_from_autograd": [ 803*da0073e9SAndroid Build Coastguard Worker "def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..." 804*da0073e9SAndroid Build Coastguard Worker ], 805*da0073e9SAndroid Build Coastguard Worker "_functionalize_are_all_mutations_hidden_from_autograd": [ 806*da0073e9SAndroid Build Coastguard Worker "def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ..." 807*da0073e9SAndroid Build Coastguard Worker ], 808*da0073e9SAndroid Build Coastguard Worker "_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [ 809*da0073e9SAndroid Build Coastguard Worker "def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ..." 810*da0073e9SAndroid Build Coastguard Worker ], 811*da0073e9SAndroid Build Coastguard Worker "_functionalize_was_inductor_storage_resized": [ 812*da0073e9SAndroid Build Coastguard Worker "def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ..." 813*da0073e9SAndroid Build Coastguard Worker ], 814*da0073e9SAndroid Build Coastguard Worker "_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."], 815*da0073e9SAndroid Build Coastguard Worker "_functionalize_was_storage_changed": [ 816*da0073e9SAndroid Build Coastguard Worker "def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..." 817*da0073e9SAndroid Build Coastguard Worker ], 818*da0073e9SAndroid Build Coastguard Worker "_functionalize_set_storage_changed": [ 819*da0073e9SAndroid Build Coastguard Worker "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..." 
820*da0073e9SAndroid Build Coastguard Worker ], 821*da0073e9SAndroid Build Coastguard Worker "_functionalize_has_metadata_mutation": [ 822*da0073e9SAndroid Build Coastguard Worker "def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..." 823*da0073e9SAndroid Build Coastguard Worker ], 824*da0073e9SAndroid Build Coastguard Worker "_functionalize_apply_view_metas": [ 825*da0073e9SAndroid Build Coastguard Worker "def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..." 826*da0073e9SAndroid Build Coastguard Worker ], 827*da0073e9SAndroid Build Coastguard Worker "_functionalize_is_symbolic": [ 828*da0073e9SAndroid Build Coastguard Worker "def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..." 829*da0073e9SAndroid Build Coastguard Worker ], 830*da0073e9SAndroid Build Coastguard Worker "_enable_functionalization": [ 831*da0073e9SAndroid Build Coastguard Worker "def _enable_functionalization(*, reapply_views: _bool = False): ..." 832*da0073e9SAndroid Build Coastguard Worker ], 833*da0073e9SAndroid Build Coastguard Worker "_disable_functionalization": ["def _disable_functionalization(): ..."], 834*da0073e9SAndroid Build Coastguard Worker "range": [ 835*da0073e9SAndroid Build Coastguard Worker "def range({}) -> Tensor: ...".format( 836*da0073e9SAndroid Build Coastguard Worker ", ".join( 837*da0073e9SAndroid Build Coastguard Worker [ 838*da0073e9SAndroid Build Coastguard Worker "start: Number", 839*da0073e9SAndroid Build Coastguard Worker "end: Number", 840*da0073e9SAndroid Build Coastguard Worker "step: Number = 1", 841*da0073e9SAndroid Build Coastguard Worker "*", 842*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 843*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 844*da0073e9SAndroid Build Coastguard Worker ] 845*da0073e9SAndroid Build Coastguard Worker ) 846*da0073e9SAndroid Build Coastguard Worker ) 847*da0073e9SAndroid Build Coastguard Worker ], 848*da0073e9SAndroid Build Coastguard 
Worker "arange": [ 849*da0073e9SAndroid Build Coastguard Worker "def arange({}) -> Tensor: ...".format( 850*da0073e9SAndroid Build Coastguard Worker ", ".join( 851*da0073e9SAndroid Build Coastguard Worker [ 852*da0073e9SAndroid Build Coastguard Worker "start: Number", 853*da0073e9SAndroid Build Coastguard Worker "end: Number", 854*da0073e9SAndroid Build Coastguard Worker "step: Number", 855*da0073e9SAndroid Build Coastguard Worker "*", 856*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 857*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 858*da0073e9SAndroid Build Coastguard Worker ] 859*da0073e9SAndroid Build Coastguard Worker ) 860*da0073e9SAndroid Build Coastguard Worker ), 861*da0073e9SAndroid Build Coastguard Worker "def arange({}) -> Tensor: ...".format( 862*da0073e9SAndroid Build Coastguard Worker ", ".join( 863*da0073e9SAndroid Build Coastguard Worker [ 864*da0073e9SAndroid Build Coastguard Worker "start: Number", 865*da0073e9SAndroid Build Coastguard Worker "end: Number", 866*da0073e9SAndroid Build Coastguard Worker "*", 867*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 868*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 869*da0073e9SAndroid Build Coastguard Worker ] 870*da0073e9SAndroid Build Coastguard Worker ) 871*da0073e9SAndroid Build Coastguard Worker ), 872*da0073e9SAndroid Build Coastguard Worker "def arange({}) -> Tensor: ...".format( 873*da0073e9SAndroid Build Coastguard Worker ", ".join( 874*da0073e9SAndroid Build Coastguard Worker [ 875*da0073e9SAndroid Build Coastguard Worker "end: Number", 876*da0073e9SAndroid Build Coastguard Worker "*", 877*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 878*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 879*da0073e9SAndroid Build Coastguard Worker ] 880*da0073e9SAndroid Build Coastguard Worker ) 881*da0073e9SAndroid Build Coastguard Worker ), 882*da0073e9SAndroid Build Coastguard Worker ], 
883*da0073e9SAndroid Build Coastguard Worker "linspace": [ 884*da0073e9SAndroid Build Coastguard Worker "def linspace({}) -> Tensor: ...".format( 885*da0073e9SAndroid Build Coastguard Worker ", ".join( 886*da0073e9SAndroid Build Coastguard Worker [ 887*da0073e9SAndroid Build Coastguard Worker "start: Number", 888*da0073e9SAndroid Build Coastguard Worker "end: Number", 889*da0073e9SAndroid Build Coastguard Worker "steps: Optional[_int] = None", 890*da0073e9SAndroid Build Coastguard Worker "*", 891*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 892*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 893*da0073e9SAndroid Build Coastguard Worker ] 894*da0073e9SAndroid Build Coastguard Worker ) 895*da0073e9SAndroid Build Coastguard Worker ) 896*da0073e9SAndroid Build Coastguard Worker ], 897*da0073e9SAndroid Build Coastguard Worker "logspace": [ 898*da0073e9SAndroid Build Coastguard Worker "def logspace({}) -> Tensor: ...".format( 899*da0073e9SAndroid Build Coastguard Worker ", ".join( 900*da0073e9SAndroid Build Coastguard Worker [ 901*da0073e9SAndroid Build Coastguard Worker "start: Number", 902*da0073e9SAndroid Build Coastguard Worker "end: Number", 903*da0073e9SAndroid Build Coastguard Worker "steps: Optional[_int] = None", 904*da0073e9SAndroid Build Coastguard Worker "base: _float = 10.0", 905*da0073e9SAndroid Build Coastguard Worker "*", 906*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 907*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 908*da0073e9SAndroid Build Coastguard Worker ] 909*da0073e9SAndroid Build Coastguard Worker ) 910*da0073e9SAndroid Build Coastguard Worker ) 911*da0073e9SAndroid Build Coastguard Worker ], 912*da0073e9SAndroid Build Coastguard Worker "randint": [ 913*da0073e9SAndroid Build Coastguard Worker "def randint({}) -> Tensor: ...".format( 914*da0073e9SAndroid Build Coastguard Worker ", ".join( 915*da0073e9SAndroid Build Coastguard Worker [ 916*da0073e9SAndroid Build 
Coastguard Worker "low: _int", 917*da0073e9SAndroid Build Coastguard Worker "high: _int", 918*da0073e9SAndroid Build Coastguard Worker "size: _size", 919*da0073e9SAndroid Build Coastguard Worker "*", 920*da0073e9SAndroid Build Coastguard Worker "generator: Optional[Generator] = None", 921*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 922*da0073e9SAndroid Build Coastguard Worker ] 923*da0073e9SAndroid Build Coastguard Worker ) 924*da0073e9SAndroid Build Coastguard Worker ), 925*da0073e9SAndroid Build Coastguard Worker "def randint({}) -> Tensor: ...".format( 926*da0073e9SAndroid Build Coastguard Worker ", ".join( 927*da0073e9SAndroid Build Coastguard Worker [ 928*da0073e9SAndroid Build Coastguard Worker "high: _int", 929*da0073e9SAndroid Build Coastguard Worker "size: _size", 930*da0073e9SAndroid Build Coastguard Worker "*", 931*da0073e9SAndroid Build Coastguard Worker "generator: Optional[Generator] = None", 932*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 933*da0073e9SAndroid Build Coastguard Worker ] 934*da0073e9SAndroid Build Coastguard Worker ) 935*da0073e9SAndroid Build Coastguard Worker ), 936*da0073e9SAndroid Build Coastguard Worker ], 937*da0073e9SAndroid Build Coastguard Worker "full": [ 938*da0073e9SAndroid Build Coastguard Worker "def full({}) -> Tensor: ...".format( 939*da0073e9SAndroid Build Coastguard Worker ", ".join( 940*da0073e9SAndroid Build Coastguard Worker [ 941*da0073e9SAndroid Build Coastguard Worker "size: _size", 942*da0073e9SAndroid Build Coastguard Worker "fill_value: Union[Number, _complex]", 943*da0073e9SAndroid Build Coastguard Worker "*", 944*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 945*da0073e9SAndroid Build Coastguard Worker "layout: _layout = strided", 946*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 947*da0073e9SAndroid Build Coastguard Worker ] 948*da0073e9SAndroid Build Coastguard Worker ) 949*da0073e9SAndroid Build Coastguard Worker ), 950*da0073e9SAndroid Build 
Coastguard Worker "def full({}) -> Tensor: ...".format( 951*da0073e9SAndroid Build Coastguard Worker ", ".join( 952*da0073e9SAndroid Build Coastguard Worker [ 953*da0073e9SAndroid Build Coastguard Worker "size: _size", 954*da0073e9SAndroid Build Coastguard Worker "fill_value: Union[Number, _complex]", 955*da0073e9SAndroid Build Coastguard Worker "*", 956*da0073e9SAndroid Build Coastguard Worker "names: List[Union[str, None]]", 957*da0073e9SAndroid Build Coastguard Worker "layout: _layout = strided", 958*da0073e9SAndroid Build Coastguard Worker FACTORY_PARAMS, 959*da0073e9SAndroid Build Coastguard Worker ] 960*da0073e9SAndroid Build Coastguard Worker ) 961*da0073e9SAndroid Build Coastguard Worker ), 962*da0073e9SAndroid Build Coastguard Worker ], 963*da0073e9SAndroid Build Coastguard Worker "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."], 964*da0073e9SAndroid Build Coastguard Worker "is_inference_mode_enabled": [ 965*da0073e9SAndroid Build Coastguard Worker "def is_inference_mode_enabled() -> _bool: ..." 
966*da0073e9SAndroid Build Coastguard Worker ], 967*da0073e9SAndroid Build Coastguard Worker "nonzero": [ 968*da0073e9SAndroid Build Coastguard Worker "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...", 969*da0073e9SAndroid Build Coastguard Worker "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", 970*da0073e9SAndroid Build Coastguard Worker ], 971*da0073e9SAndroid Build Coastguard Worker "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], 972*da0073e9SAndroid Build Coastguard Worker "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], 973*da0073e9SAndroid Build Coastguard Worker "saddmm": [ 974*da0073e9SAndroid Build Coastguard Worker "def saddmm({}) -> Tensor: ...".format( 975*da0073e9SAndroid Build Coastguard Worker ", ".join( 976*da0073e9SAndroid Build Coastguard Worker [ 977*da0073e9SAndroid Build Coastguard Worker "input: Tensor", 978*da0073e9SAndroid Build Coastguard Worker "mat1: Tensor", 979*da0073e9SAndroid Build Coastguard Worker "mat2: Tensor", 980*da0073e9SAndroid Build Coastguard Worker "*", 981*da0073e9SAndroid Build Coastguard Worker "beta: Number = 1", 982*da0073e9SAndroid Build Coastguard Worker "alpha: Number = 1", 983*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 984*da0073e9SAndroid Build Coastguard Worker ] 985*da0073e9SAndroid Build Coastguard Worker ) 986*da0073e9SAndroid Build Coastguard Worker ) 987*da0073e9SAndroid Build Coastguard Worker ], 988*da0073e9SAndroid Build Coastguard Worker "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], 989*da0073e9SAndroid Build Coastguard Worker "div": [ 990*da0073e9SAndroid Build Coastguard Worker "def div({}) -> Tensor: ...".format( 991*da0073e9SAndroid Build Coastguard Worker ", ".join( 992*da0073e9SAndroid Build Coastguard Worker [ 993*da0073e9SAndroid Build Coastguard Worker "input: Union[Tensor, Number]", 994*da0073e9SAndroid Build 
Coastguard Worker "other: Union[Tensor, Number]", 995*da0073e9SAndroid Build Coastguard Worker "*", 996*da0073e9SAndroid Build Coastguard Worker "rounding_mode: Optional[str] = None", 997*da0073e9SAndroid Build Coastguard Worker "out: Optional[Tensor] = None", 998*da0073e9SAndroid Build Coastguard Worker ] 999*da0073e9SAndroid Build Coastguard Worker ) 1000*da0073e9SAndroid Build Coastguard Worker ) 1001*da0073e9SAndroid Build Coastguard Worker ], 1002*da0073e9SAndroid Build Coastguard Worker } 1003*da0073e9SAndroid Build Coastguard Worker ) 1004*da0073e9SAndroid Build Coastguard Worker for binop in ["true_divide", "floor_divide"]: 1005*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints[binop].append( 1006*da0073e9SAndroid Build Coastguard Worker f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], " 1007*da0073e9SAndroid Build Coastguard Worker "*, out: Optional[Tensor] = None) -> Tensor: ..." 1008*da0073e9SAndroid Build Coastguard Worker ) 1009*da0073e9SAndroid Build Coastguard Worker for binop in ["mul"]: 1010*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints[binop].append( 1011*da0073e9SAndroid Build Coastguard Worker f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], " 1012*da0073e9SAndroid Build Coastguard Worker "*, out: Optional[Tensor] = None) -> Tensor: ..." 1013*da0073e9SAndroid Build Coastguard Worker ) 1014*da0073e9SAndroid Build Coastguard Worker for binop in ["add", "sub"]: 1015*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints[binop].append( 1016*da0073e9SAndroid Build Coastguard Worker f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], " 1017*da0073e9SAndroid Build Coastguard Worker "*, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: ..." 
1018*da0073e9SAndroid Build Coastguard Worker ) 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker native_functions = parse_native_yaml( 1021*da0073e9SAndroid Build Coastguard Worker native_yaml_path, tags_yaml_path 1022*da0073e9SAndroid Build Coastguard Worker ).native_functions 1023*da0073e9SAndroid Build Coastguard Worker native_functions = list(filter(should_generate_py_binding, native_functions)) 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker function_signatures = load_signatures( 1026*da0073e9SAndroid Build Coastguard Worker native_functions, deprecated_yaml_path, method=False, pyi=True 1027*da0073e9SAndroid Build Coastguard Worker ) 1028*da0073e9SAndroid Build Coastguard Worker sig_groups = get_py_torch_functions(function_signatures) 1029*da0073e9SAndroid Build Coastguard Worker for group in sorted(sig_groups, key=lambda g: g.signature.name): 1030*da0073e9SAndroid Build Coastguard Worker name = group.signature.name 1031*da0073e9SAndroid Build Coastguard Worker unsorted_function_hints[name] += generate_type_hints(group) 1032*da0073e9SAndroid Build Coastguard Worker 1033*da0073e9SAndroid Build Coastguard Worker structseq = returns_structseq_pyi(group.signature) 1034*da0073e9SAndroid Build Coastguard Worker if structseq is not None and not group.signature.deprecated: 1035*da0073e9SAndroid Build Coastguard Worker # deprecated structseqs are currently not included for torch functions 1036*da0073e9SAndroid Build Coastguard Worker tuple_name, tuple_def = structseq 1037*da0073e9SAndroid Build Coastguard Worker if tuple_name in structseqs: 1038*da0073e9SAndroid Build Coastguard Worker assert structseqs[tuple_name] == tuple_def 1039*da0073e9SAndroid Build Coastguard Worker else: 1040*da0073e9SAndroid Build Coastguard Worker structseqs[tuple_name] = tuple_def 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker def replace_special_case(hint: str) 
-> str: 1043*da0073e9SAndroid Build Coastguard Worker # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h 1044*da0073e9SAndroid Build Coastguard Worker hint = hint.replace("at::Reduction::Mean", "1") 1045*da0073e9SAndroid Build Coastguard Worker hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None") 1046*da0073e9SAndroid Build Coastguard Worker # Match both: 1047*da0073e9SAndroid Build Coastguard Worker # ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None" 1048*da0073e9SAndroid Build Coastguard Worker # ": Union[Tuple[Tensor, ...], List[Tensor]] = None" 1049*da0073e9SAndroid Build Coastguard Worker hint = hint.replace( 1050*da0073e9SAndroid Build Coastguard Worker "Tuple[Tensor, ...], List[Tensor]] = None", 1051*da0073e9SAndroid Build Coastguard Worker "Tuple[Tensor, ...], List[Tensor], None] = None", 1052*da0073e9SAndroid Build Coastguard Worker ) 1053*da0073e9SAndroid Build Coastguard Worker return hint 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard Worker docstrs = gather_docstrs() 1056*da0073e9SAndroid Build Coastguard Worker function_hints = [] 1057*da0073e9SAndroid Build Coastguard Worker for name, hints in sorted(unsorted_function_hints.items()): 1058*da0073e9SAndroid Build Coastguard Worker hints = [replace_special_case(h) for h in hints] 1059*da0073e9SAndroid Build Coastguard Worker if len(hints) > 1: 1060*da0073e9SAndroid Build Coastguard Worker hints = ["@overload\n" + h for h in hints] 1061*da0073e9SAndroid Build Coastguard Worker docstr = docstrs.get(f"torch.{name}") 1062*da0073e9SAndroid Build Coastguard Worker if docstr is not None: 1063*da0073e9SAndroid Build Coastguard Worker hints = [add_docstr_to_hint(docstr, h) for h in hints] 1064*da0073e9SAndroid Build Coastguard Worker function_hints += hints 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Worker # Generate type signatures for Tensor methods 1067*da0073e9SAndroid Build Coastguard 
Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1068*da0073e9SAndroid Build Coastguard Worker 1069*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list) 1070*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints.update( 1071*da0073e9SAndroid Build Coastguard Worker { 1072*da0073e9SAndroid Build Coastguard Worker "size": [ 1073*da0073e9SAndroid Build Coastguard Worker "def size(self, dim: None = None) -> Size: ...", 1074*da0073e9SAndroid Build Coastguard Worker "def size(self, dim: _int) -> _int: ...", 1075*da0073e9SAndroid Build Coastguard Worker ], 1076*da0073e9SAndroid Build Coastguard Worker "stride": [ 1077*da0073e9SAndroid Build Coastguard Worker "def stride(self, dim: None = None) -> Tuple[_int, ...]: ...", 1078*da0073e9SAndroid Build Coastguard Worker "def stride(self, dim: _int) -> _int: ...", 1079*da0073e9SAndroid Build Coastguard Worker ], 1080*da0073e9SAndroid Build Coastguard Worker "new_ones": [ 1081*da0073e9SAndroid Build Coastguard Worker f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..." 1082*da0073e9SAndroid Build Coastguard Worker ], 1083*da0073e9SAndroid Build Coastguard Worker "new_tensor": [ 1084*da0073e9SAndroid Build Coastguard Worker f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..." 
1085*da0073e9SAndroid Build Coastguard Worker ], 1086*da0073e9SAndroid Build Coastguard Worker "__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."], 1087*da0073e9SAndroid Build Coastguard Worker # new and __init__ have the same signatures differ only in return type 1088*da0073e9SAndroid Build Coastguard Worker # Adapted from legacy_tensor_ctor and legacy_tensor_new 1089*da0073e9SAndroid Build Coastguard Worker "new": [ 1090*da0073e9SAndroid Build Coastguard Worker f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...", 1091*da0073e9SAndroid Build Coastguard Worker "def new(cls, storage: Storage) -> Self: ...", 1092*da0073e9SAndroid Build Coastguard Worker "def new(cls, other: Tensor) -> Self: ...", 1093*da0073e9SAndroid Build Coastguard Worker f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...", 1094*da0073e9SAndroid Build Coastguard Worker ], 1095*da0073e9SAndroid Build Coastguard Worker "__init__": [ 1096*da0073e9SAndroid Build Coastguard Worker f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...", 1097*da0073e9SAndroid Build Coastguard Worker "def __init__(self, storage: Storage) -> None: ...", 1098*da0073e9SAndroid Build Coastguard Worker "def __init__(self, other: Tensor) -> None: ...", 1099*da0073e9SAndroid Build Coastguard Worker f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...", 1100*da0073e9SAndroid Build Coastguard Worker ], 1101*da0073e9SAndroid Build Coastguard Worker "as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."], 1102*da0073e9SAndroid Build Coastguard Worker "_make_subclass": [ 1103*da0073e9SAndroid Build Coastguard Worker "@staticmethod \ndef _make_subclass({}) -> S: ...".format( 1104*da0073e9SAndroid Build Coastguard Worker ", ".join( 1105*da0073e9SAndroid Build Coastguard Worker [ 1106*da0073e9SAndroid Build Coastguard Worker "cls: _Type[S]", 1107*da0073e9SAndroid Build Coastguard Worker "data: Tensor", 1108*da0073e9SAndroid Build Coastguard Worker "require_grad: _bool = False", 
1109*da0073e9SAndroid Build Coastguard Worker "dispatch_strides: _bool = False", 1110*da0073e9SAndroid Build Coastguard Worker "dispatch_device: _bool = False", 1111*da0073e9SAndroid Build Coastguard Worker "device_for_backend_keys: Optional[_device] = None", 1112*da0073e9SAndroid Build Coastguard Worker ] 1113*da0073e9SAndroid Build Coastguard Worker ) 1114*da0073e9SAndroid Build Coastguard Worker ) 1115*da0073e9SAndroid Build Coastguard Worker ], 1116*da0073e9SAndroid Build Coastguard Worker "__contains__": ["def __contains__(self, other: Any, /) -> _bool: ..."], 1117*da0073e9SAndroid Build Coastguard Worker "__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."], 1118*da0073e9SAndroid Build Coastguard Worker "__setitem__": [ 1119*da0073e9SAndroid Build Coastguard Worker f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..." 1120*da0073e9SAndroid Build Coastguard Worker ], 1121*da0073e9SAndroid Build Coastguard Worker "tolist": ["def tolist(self) -> List: ..."], 1122*da0073e9SAndroid Build Coastguard Worker "requires_grad_": [ 1123*da0073e9SAndroid Build Coastguard Worker "def requires_grad_(self, mode: _bool = True) -> Tensor: ..." 
1124*da0073e9SAndroid Build Coastguard Worker ], 1125*da0073e9SAndroid Build Coastguard Worker "element_size": ["def element_size(self) -> _int: ..."], 1126*da0073e9SAndroid Build Coastguard Worker "data_ptr": ["def data_ptr(self) -> _int: ..."], 1127*da0073e9SAndroid Build Coastguard Worker "dim": ["def dim(self) -> _int: ..."], 1128*da0073e9SAndroid Build Coastguard Worker "nonzero": [ 1129*da0073e9SAndroid Build Coastguard Worker "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...", 1130*da0073e9SAndroid Build Coastguard Worker "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", 1131*da0073e9SAndroid Build Coastguard Worker ], 1132*da0073e9SAndroid Build Coastguard Worker "numel": ["def numel(self) -> _int: ..."], 1133*da0073e9SAndroid Build Coastguard Worker "ndimension": ["def ndimension(self) -> _int: ..."], 1134*da0073e9SAndroid Build Coastguard Worker "nelement": ["def nelement(self) -> _int: ..."], 1135*da0073e9SAndroid Build Coastguard Worker "cuda": [ 1136*da0073e9SAndroid Build Coastguard Worker "def cuda({}) -> Tensor: ...".format( 1137*da0073e9SAndroid Build Coastguard Worker ", ".join( 1138*da0073e9SAndroid Build Coastguard Worker [ 1139*da0073e9SAndroid Build Coastguard Worker "self", 1140*da0073e9SAndroid Build Coastguard Worker "device: Optional[Union[_device, _int, str]] = None", 1141*da0073e9SAndroid Build Coastguard Worker "non_blocking: _bool = False", 1142*da0073e9SAndroid Build Coastguard Worker "memory_format: torch.memory_format = torch.preserve_format", 1143*da0073e9SAndroid Build Coastguard Worker ] 1144*da0073e9SAndroid Build Coastguard Worker ) 1145*da0073e9SAndroid Build Coastguard Worker ) 1146*da0073e9SAndroid Build Coastguard Worker ], 1147*da0073e9SAndroid Build Coastguard Worker "xpu": [ 1148*da0073e9SAndroid Build Coastguard Worker "def xpu({}) -> Tensor: ...".format( 1149*da0073e9SAndroid Build Coastguard Worker ", ".join( 1150*da0073e9SAndroid Build Coastguard Worker [ 
1151*da0073e9SAndroid Build Coastguard Worker "self", 1152*da0073e9SAndroid Build Coastguard Worker "device: Optional[Union[_device, _int, str]] = None", 1153*da0073e9SAndroid Build Coastguard Worker "non_blocking: _bool = False", 1154*da0073e9SAndroid Build Coastguard Worker "memory_format: torch.memory_format = torch.preserve_format", 1155*da0073e9SAndroid Build Coastguard Worker ] 1156*da0073e9SAndroid Build Coastguard Worker ) 1157*da0073e9SAndroid Build Coastguard Worker ) 1158*da0073e9SAndroid Build Coastguard Worker ], 1159*da0073e9SAndroid Build Coastguard Worker "cpu": [ 1160*da0073e9SAndroid Build Coastguard Worker "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..." 1161*da0073e9SAndroid Build Coastguard Worker ], 1162*da0073e9SAndroid Build Coastguard Worker "numpy": ["def numpy(self, *, force: _bool = False) -> numpy.ndarray: ..."], 1163*da0073e9SAndroid Build Coastguard Worker "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."], 1164*da0073e9SAndroid Build Coastguard Worker "map_": [ 1165*da0073e9SAndroid Build Coastguard Worker "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..." 1166*da0073e9SAndroid Build Coastguard Worker ], 1167*da0073e9SAndroid Build Coastguard Worker "map2_": [ 1168*da0073e9SAndroid Build Coastguard Worker "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..." 
1169*da0073e9SAndroid Build Coastguard Worker ], 1170*da0073e9SAndroid Build Coastguard Worker "storage": ["def untyped_storage(self) -> UntypedStorage: ..."], 1171*da0073e9SAndroid Build Coastguard Worker "storage_type": ["def storage_type(self) -> Storage: ..."], 1172*da0073e9SAndroid Build Coastguard Worker "type": [ 1173*da0073e9SAndroid Build Coastguard Worker "def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...", 1174*da0073e9SAndroid Build Coastguard Worker "def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...", 1175*da0073e9SAndroid Build Coastguard Worker ], 1176*da0073e9SAndroid Build Coastguard Worker "get_device": ["def get_device(self) -> _int: ..."], 1177*da0073e9SAndroid Build Coastguard Worker "contiguous": [ 1178*da0073e9SAndroid Build Coastguard Worker "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..." 1179*da0073e9SAndroid Build Coastguard Worker ], 1180*da0073e9SAndroid Build Coastguard Worker "has_names": ["def has_names(self) -> _bool: ..."], 1181*da0073e9SAndroid Build Coastguard Worker "is_contiguous": [ 1182*da0073e9SAndroid Build Coastguard Worker "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..." 
1183*da0073e9SAndroid Build Coastguard Worker ], 1184*da0073e9SAndroid Build Coastguard Worker "_is_view": ["def _is_view(self) -> _bool: ..."], 1185*da0073e9SAndroid Build Coastguard Worker "is_cpu": ["is_cpu: _bool"], 1186*da0073e9SAndroid Build Coastguard Worker "is_cuda": ["is_cuda: _bool"], 1187*da0073e9SAndroid Build Coastguard Worker "is_leaf": ["is_leaf: _bool"], 1188*da0073e9SAndroid Build Coastguard Worker "is_nested": ["is_nested: _bool"], 1189*da0073e9SAndroid Build Coastguard Worker "is_sparse": ["is_sparse: _bool"], 1190*da0073e9SAndroid Build Coastguard Worker "is_sparse_csr": ["is_sparse_csr: _bool"], 1191*da0073e9SAndroid Build Coastguard Worker "is_quantized": ["is_quantized: _bool"], 1192*da0073e9SAndroid Build Coastguard Worker "is_meta": ["is_meta: _bool"], 1193*da0073e9SAndroid Build Coastguard Worker "is_mps": ["is_mps: _bool"], 1194*da0073e9SAndroid Build Coastguard Worker "is_mtia": ["is_mtia: _bool"], 1195*da0073e9SAndroid Build Coastguard Worker "is_maia": ["is_maia: _bool"], 1196*da0073e9SAndroid Build Coastguard Worker "is_mkldnn": ["is_mkldnn: _bool"], 1197*da0073e9SAndroid Build Coastguard Worker "is_vulkan": ["is_vulkan: _bool"], 1198*da0073e9SAndroid Build Coastguard Worker "is_ipu": ["is_ipu: _bool"], 1199*da0073e9SAndroid Build Coastguard Worker "storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."], 1200*da0073e9SAndroid Build Coastguard Worker "to": [ 1201*da0073e9SAndroid Build Coastguard Worker ( 1202*da0073e9SAndroid Build Coastguard Worker f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, " 1203*da0073e9SAndroid Build Coastguard Worker "memory_format: Optional[torch.memory_format] = None) -> Tensor: ..." 
1204*da0073e9SAndroid Build Coastguard Worker ) 1205*da0073e9SAndroid Build Coastguard Worker for args in [ 1206*da0073e9SAndroid Build Coastguard Worker "dtype: _dtype", 1207*da0073e9SAndroid Build Coastguard Worker "device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None", 1208*da0073e9SAndroid Build Coastguard Worker "other: Tensor", 1209*da0073e9SAndroid Build Coastguard Worker ] 1210*da0073e9SAndroid Build Coastguard Worker ], 1211*da0073e9SAndroid Build Coastguard Worker "item": ["def item(self) -> Number: ..."], 1212*da0073e9SAndroid Build Coastguard Worker "copy_": [ 1213*da0073e9SAndroid Build Coastguard Worker "def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..." 1214*da0073e9SAndroid Build Coastguard Worker ], 1215*da0073e9SAndroid Build Coastguard Worker "set_": [ 1216*da0073e9SAndroid Build Coastguard Worker "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], " 1217*da0073e9SAndroid Build Coastguard Worker "offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...", 1218*da0073e9SAndroid Build Coastguard Worker "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...", 1219*da0073e9SAndroid Build Coastguard Worker ], 1220*da0073e9SAndroid Build Coastguard Worker "split": [ 1221*da0073e9SAndroid Build Coastguard Worker "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...", 1222*da0073e9SAndroid Build Coastguard Worker "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...", 1223*da0073e9SAndroid Build Coastguard Worker ], 1224*da0073e9SAndroid Build Coastguard Worker "div": [ 1225*da0073e9SAndroid Build Coastguard Worker "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." 
1226*da0073e9SAndroid Build Coastguard Worker ], 1227*da0073e9SAndroid Build Coastguard Worker "div_": [ 1228*da0073e9SAndroid Build Coastguard Worker "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." 1229*da0073e9SAndroid Build Coastguard Worker ], 1230*da0073e9SAndroid Build Coastguard Worker } 1231*da0073e9SAndroid Build Coastguard Worker ) 1232*da0073e9SAndroid Build Coastguard Worker for binop in ["true_divide", "floor_divide"]: 1233*da0073e9SAndroid Build Coastguard Worker for inplace in [False, True]: 1234*da0073e9SAndroid Build Coastguard Worker out_suffix = ", *, out: Optional[Tensor] = None" 1235*da0073e9SAndroid Build Coastguard Worker if inplace: 1236*da0073e9SAndroid Build Coastguard Worker binop += "_" 1237*da0073e9SAndroid Build Coastguard Worker out_suffix = "" 1238*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints[binop].append( 1239*da0073e9SAndroid Build Coastguard Worker f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})" 1240*da0073e9SAndroid Build Coastguard Worker " -> Tensor: ..." 1241*da0073e9SAndroid Build Coastguard Worker ) 1242*da0073e9SAndroid Build Coastguard Worker for binop in ["mul"]: 1243*da0073e9SAndroid Build Coastguard Worker for inplace in [False, True]: 1244*da0073e9SAndroid Build Coastguard Worker out_suffix = ", *, out: Optional[Tensor] = None" 1245*da0073e9SAndroid Build Coastguard Worker if inplace: 1246*da0073e9SAndroid Build Coastguard Worker binop += "_" 1247*da0073e9SAndroid Build Coastguard Worker out_suffix = "" 1248*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints[binop].append( 1249*da0073e9SAndroid Build Coastguard Worker f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat]{out_suffix})" 1250*da0073e9SAndroid Build Coastguard Worker " -> Tensor: ..." 
1251*da0073e9SAndroid Build Coastguard Worker ) 1252*da0073e9SAndroid Build Coastguard Worker for binop in ["add", "sub"]: 1253*da0073e9SAndroid Build Coastguard Worker for inplace in [False, True]: 1254*da0073e9SAndroid Build Coastguard Worker out_suffix = ", out: Optional[Tensor] = None" 1255*da0073e9SAndroid Build Coastguard Worker if inplace: 1256*da0073e9SAndroid Build Coastguard Worker binop += "_" 1257*da0073e9SAndroid Build Coastguard Worker out_suffix = "" 1258*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints[binop].append( 1259*da0073e9SAndroid Build Coastguard Worker f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], " 1260*da0073e9SAndroid Build Coastguard Worker f"*, alpha: Optional[Union[Number, _complex]] = 1{out_suffix})" 1261*da0073e9SAndroid Build Coastguard Worker " -> Tensor: ..." 1262*da0073e9SAndroid Build Coastguard Worker ) 1263*da0073e9SAndroid Build Coastguard Worker simple_conversions = [ 1264*da0073e9SAndroid Build Coastguard Worker "byte", 1265*da0073e9SAndroid Build Coastguard Worker "char", 1266*da0073e9SAndroid Build Coastguard Worker "double", 1267*da0073e9SAndroid Build Coastguard Worker "float", 1268*da0073e9SAndroid Build Coastguard Worker "half", 1269*da0073e9SAndroid Build Coastguard Worker "int", 1270*da0073e9SAndroid Build Coastguard Worker "long", 1271*da0073e9SAndroid Build Coastguard Worker "short", 1272*da0073e9SAndroid Build Coastguard Worker "bool", 1273*da0073e9SAndroid Build Coastguard Worker "bfloat16", 1274*da0073e9SAndroid Build Coastguard Worker ] 1275*da0073e9SAndroid Build Coastguard Worker for name in simple_conversions: 1276*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...") 1277*da0073e9SAndroid Build Coastguard Worker 1278*da0073e9SAndroid Build Coastguard Worker # pyi tensor methods don't currently include deprecated signatures for some reason 1279*da0073e9SAndroid Build Coastguard 
Worker # TODO: we should probably add them in 1280*da0073e9SAndroid Build Coastguard Worker tensor_method_signatures = load_signatures( 1281*da0073e9SAndroid Build Coastguard Worker native_functions, 1282*da0073e9SAndroid Build Coastguard Worker deprecated_yaml_path, 1283*da0073e9SAndroid Build Coastguard Worker method=True, 1284*da0073e9SAndroid Build Coastguard Worker skip_deprecated=True, 1285*da0073e9SAndroid Build Coastguard Worker pyi=True, 1286*da0073e9SAndroid Build Coastguard Worker ) 1287*da0073e9SAndroid Build Coastguard Worker tensor_method_sig_groups = get_py_torch_functions( 1288*da0073e9SAndroid Build Coastguard Worker tensor_method_signatures, method=True 1289*da0073e9SAndroid Build Coastguard Worker ) 1290*da0073e9SAndroid Build Coastguard Worker 1291*da0073e9SAndroid Build Coastguard Worker for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): 1292*da0073e9SAndroid Build Coastguard Worker name = group.signature.name 1293*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints[name] += generate_type_hints(group) 1294*da0073e9SAndroid Build Coastguard Worker 1295*da0073e9SAndroid Build Coastguard Worker structseq = returns_structseq_pyi(group.signature) 1296*da0073e9SAndroid Build Coastguard Worker if structseq is not None and not group.signature.deprecated: 1297*da0073e9SAndroid Build Coastguard Worker # deprecated structseqs are currently not included for torch functions 1298*da0073e9SAndroid Build Coastguard Worker tuple_name, tuple_def = structseq 1299*da0073e9SAndroid Build Coastguard Worker if tuple_name in structseqs: 1300*da0073e9SAndroid Build Coastguard Worker assert structseqs[tuple_name] == tuple_def 1301*da0073e9SAndroid Build Coastguard Worker else: 1302*da0073e9SAndroid Build Coastguard Worker structseqs[tuple_name] = tuple_def 1303*da0073e9SAndroid Build Coastguard Worker 1304*da0073e9SAndroid Build Coastguard Worker for op in all_ops: 1305*da0073e9SAndroid Build Coastguard Worker name = 
f"__{op}__" 1306*da0073e9SAndroid Build Coastguard Worker unsorted_tensor_method_hints[name] += sig_for_ops(name) 1307*da0073e9SAndroid Build Coastguard Worker 1308*da0073e9SAndroid Build Coastguard Worker tensor_method_hints = [] 1309*da0073e9SAndroid Build Coastguard Worker for name, hints in sorted(unsorted_tensor_method_hints.items()): 1310*da0073e9SAndroid Build Coastguard Worker if len(hints) > 1: 1311*da0073e9SAndroid Build Coastguard Worker hints = ["@overload\n" + h for h in hints] 1312*da0073e9SAndroid Build Coastguard Worker docstr = docstrs.get(f"torch._C.TensorBase.{name}") 1313*da0073e9SAndroid Build Coastguard Worker if docstr is not None: 1314*da0073e9SAndroid Build Coastguard Worker hints = [add_docstr_to_hint(docstr, h) for h in hints] 1315*da0073e9SAndroid Build Coastguard Worker tensor_method_hints += hints 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker # TODO: Missing type hints for nn 1318*da0073e9SAndroid Build Coastguard Worker 1319*da0073e9SAndroid Build Coastguard Worker # Generate structseq definitions 1320*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1321*da0073e9SAndroid Build Coastguard Worker 1322*da0073e9SAndroid Build Coastguard Worker structseq_defs = [f"{defn}\n" for defn in structseqs.values()] 1323*da0073e9SAndroid Build Coastguard Worker 1324*da0073e9SAndroid Build Coastguard Worker # Generate type signatures for legacy classes 1325*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker legacy_storage_base_hints = ["class StorageBase(object): ..."] 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker legacy_class_hints = [] 1330*da0073e9SAndroid Build Coastguard Worker for c in ( 1331*da0073e9SAndroid Build Coastguard Worker "DoubleTensor", 1332*da0073e9SAndroid Build Coastguard 
Worker "FloatTensor", 1333*da0073e9SAndroid Build Coastguard Worker "BFloat16Tensor", 1334*da0073e9SAndroid Build Coastguard Worker "LongTensor", 1335*da0073e9SAndroid Build Coastguard Worker "IntTensor", 1336*da0073e9SAndroid Build Coastguard Worker "ShortTensor", 1337*da0073e9SAndroid Build Coastguard Worker "HalfTensor", 1338*da0073e9SAndroid Build Coastguard Worker "CharTensor", 1339*da0073e9SAndroid Build Coastguard Worker "ByteTensor", 1340*da0073e9SAndroid Build Coastguard Worker "BoolTensor", 1341*da0073e9SAndroid Build Coastguard Worker ): 1342*da0073e9SAndroid Build Coastguard Worker legacy_class_hints.append(f"class {c}(Tensor): ...") 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker # Generate type signatures for dtype classes 1345*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker # TODO: don't explicitly list dtypes here; get it from canonical 1348*da0073e9SAndroid Build Coastguard Worker # source 1349*da0073e9SAndroid Build Coastguard Worker dtype_class_hints = [ 1350*da0073e9SAndroid Build Coastguard Worker f"{n}: dtype = ..." 
1351*da0073e9SAndroid Build Coastguard Worker for n in [ 1352*da0073e9SAndroid Build Coastguard Worker "float32", 1353*da0073e9SAndroid Build Coastguard Worker "float", 1354*da0073e9SAndroid Build Coastguard Worker "float64", 1355*da0073e9SAndroid Build Coastguard Worker "double", 1356*da0073e9SAndroid Build Coastguard Worker "float16", 1357*da0073e9SAndroid Build Coastguard Worker "bfloat16", 1358*da0073e9SAndroid Build Coastguard Worker "float8_e4m3fn", 1359*da0073e9SAndroid Build Coastguard Worker "float8_e4m3fnuz", 1360*da0073e9SAndroid Build Coastguard Worker "float8_e5m2", 1361*da0073e9SAndroid Build Coastguard Worker "float8_e5m2fnuz", 1362*da0073e9SAndroid Build Coastguard Worker "half", 1363*da0073e9SAndroid Build Coastguard Worker "uint8", 1364*da0073e9SAndroid Build Coastguard Worker "uint16", 1365*da0073e9SAndroid Build Coastguard Worker "uint32", 1366*da0073e9SAndroid Build Coastguard Worker "uint64", 1367*da0073e9SAndroid Build Coastguard Worker "int8", 1368*da0073e9SAndroid Build Coastguard Worker "int16", 1369*da0073e9SAndroid Build Coastguard Worker "short", 1370*da0073e9SAndroid Build Coastguard Worker "int32", 1371*da0073e9SAndroid Build Coastguard Worker "int", 1372*da0073e9SAndroid Build Coastguard Worker "int64", 1373*da0073e9SAndroid Build Coastguard Worker "long", 1374*da0073e9SAndroid Build Coastguard Worker "complex32", 1375*da0073e9SAndroid Build Coastguard Worker "complex64", 1376*da0073e9SAndroid Build Coastguard Worker "chalf", 1377*da0073e9SAndroid Build Coastguard Worker "cfloat", 1378*da0073e9SAndroid Build Coastguard Worker "complex128", 1379*da0073e9SAndroid Build Coastguard Worker "cdouble", 1380*da0073e9SAndroid Build Coastguard Worker "quint8", 1381*da0073e9SAndroid Build Coastguard Worker "qint8", 1382*da0073e9SAndroid Build Coastguard Worker "qint32", 1383*da0073e9SAndroid Build Coastguard Worker "bool", 1384*da0073e9SAndroid Build Coastguard Worker "quint4x2", 1385*da0073e9SAndroid Build Coastguard Worker "quint2x4", 
1386*da0073e9SAndroid Build Coastguard Worker "bits1x8", 1387*da0073e9SAndroid Build Coastguard Worker "bits2x4", 1388*da0073e9SAndroid Build Coastguard Worker "bits4x2", 1389*da0073e9SAndroid Build Coastguard Worker "bits8", 1390*da0073e9SAndroid Build Coastguard Worker "bits16", 1391*da0073e9SAndroid Build Coastguard Worker ] 1392*da0073e9SAndroid Build Coastguard Worker ] 1393*da0073e9SAndroid Build Coastguard Worker 1394*da0073e9SAndroid Build Coastguard Worker # Generate __all__ directive 1395*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker # Include only the functions that contain hints, to prevent undefined 1398*da0073e9SAndroid Build Coastguard Worker # symbols to be included in the `__all__` directive. 1399*da0073e9SAndroid Build Coastguard Worker hinted_function_names = [ 1400*da0073e9SAndroid Build Coastguard Worker name for name, hint in unsorted_function_hints.items() if hint 1401*da0073e9SAndroid Build Coastguard Worker ] 1402*da0073e9SAndroid Build Coastguard Worker all_symbols = sorted(list(structseqs.keys()) + hinted_function_names) 1403*da0073e9SAndroid Build Coastguard Worker all_directive = pformat(all_symbols, width=100, compact=True).split("\n") 1404*da0073e9SAndroid Build Coastguard Worker all_directive[0] = f"__all__ = {all_directive[0]}" 1405*da0073e9SAndroid Build Coastguard Worker 1406*da0073e9SAndroid Build Coastguard Worker # Dispatch key hints 1407*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~ 1408*da0073e9SAndroid Build Coastguard Worker dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey] 1409*da0073e9SAndroid Build Coastguard Worker torch_dispatch_mode_key_hints = [ 1410*da0073e9SAndroid Build Coastguard Worker f"{k.name}: _TorchDispatchModeKey = ..." 
for k in _TorchDispatchModeKey 1411*da0073e9SAndroid Build Coastguard Worker ] 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker # Tags Enum type hints 1414*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~ 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker tag_names = sorted(parse_tags_yaml(tags_yaml_path)) 1417*da0073e9SAndroid Build Coastguard Worker tag_attributes = "\n".join( 1418*da0073e9SAndroid Build Coastguard Worker f"{name}: _int = {index}" for index, name in enumerate(tag_names) 1419*da0073e9SAndroid Build Coastguard Worker ) 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker # Write out the stub 1422*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~ 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker env = { 1425*da0073e9SAndroid Build Coastguard Worker "structseq_defs": structseq_defs, 1426*da0073e9SAndroid Build Coastguard Worker "function_hints": function_hints, 1427*da0073e9SAndroid Build Coastguard Worker "tensor_method_hints": tensor_method_hints, 1428*da0073e9SAndroid Build Coastguard Worker "legacy_class_hints": legacy_class_hints, 1429*da0073e9SAndroid Build Coastguard Worker "legacy_storage_base_hints": legacy_storage_base_hints, 1430*da0073e9SAndroid Build Coastguard Worker "dtype_class_hints": dtype_class_hints, 1431*da0073e9SAndroid Build Coastguard Worker "dispatch_key_hints": dispatch_key_hints, 1432*da0073e9SAndroid Build Coastguard Worker "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints, 1433*da0073e9SAndroid Build Coastguard Worker "all_directive": all_directive, 1434*da0073e9SAndroid Build Coastguard Worker "tag_attributes": tag_attributes, 1435*da0073e9SAndroid Build Coastguard Worker } 1436*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 1437*da0073e9SAndroid Build Coastguard Worker "torch/_C/__init__.pyi", 
1438*da0073e9SAndroid Build Coastguard Worker "torch/_C/__init__.pyi.in", 1439*da0073e9SAndroid Build Coastguard Worker lambda: env, 1440*da0073e9SAndroid Build Coastguard Worker ) 1441*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 1442*da0073e9SAndroid Build Coastguard Worker "torch/_C/_VariableFunctions.pyi", 1443*da0073e9SAndroid Build Coastguard Worker "torch/_C/_VariableFunctions.pyi.in", 1444*da0073e9SAndroid Build Coastguard Worker lambda: env, 1445*da0073e9SAndroid Build Coastguard Worker ) 1446*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 1447*da0073e9SAndroid Build Coastguard Worker "torch/_VF.pyi", 1448*da0073e9SAndroid Build Coastguard Worker "torch/_C/_VariableFunctions.pyi.in", 1449*da0073e9SAndroid Build Coastguard Worker lambda: env, 1450*da0073e9SAndroid Build Coastguard Worker ) 1451*da0073e9SAndroid Build Coastguard Worker fm.write_with_template( 1452*da0073e9SAndroid Build Coastguard Worker "torch/return_types.pyi", 1453*da0073e9SAndroid Build Coastguard Worker "torch/_C/return_types.pyi.in", 1454*da0073e9SAndroid Build Coastguard Worker lambda: env, 1455*da0073e9SAndroid Build Coastguard Worker ) 1456*da0073e9SAndroid Build Coastguard Worker gen_nn_functional(fm) 1457*da0073e9SAndroid Build Coastguard Worker 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Workerdef main() -> None: 1460*da0073e9SAndroid Build Coastguard Worker parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch") 1461*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1462*da0073e9SAndroid Build Coastguard Worker "--native-functions-path", 1463*da0073e9SAndroid Build Coastguard Worker metavar="NATIVE", 1464*da0073e9SAndroid Build Coastguard Worker default="aten/src/ATen/native/native_functions.yaml", 1465*da0073e9SAndroid Build Coastguard Worker help="path to native_functions.yaml", 1466*da0073e9SAndroid Build Coastguard Worker ) 1467*da0073e9SAndroid Build Coastguard 
Worker parser.add_argument( 1468*da0073e9SAndroid Build Coastguard Worker "--tags-path", 1469*da0073e9SAndroid Build Coastguard Worker metavar="TAGS", 1470*da0073e9SAndroid Build Coastguard Worker default="aten/src/ATen/native/tags.yaml", 1471*da0073e9SAndroid Build Coastguard Worker help="path to tags.yaml", 1472*da0073e9SAndroid Build Coastguard Worker ) 1473*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1474*da0073e9SAndroid Build Coastguard Worker "--deprecated-functions-path", 1475*da0073e9SAndroid Build Coastguard Worker metavar="DEPRECATED", 1476*da0073e9SAndroid Build Coastguard Worker default="tools/autograd/deprecated.yaml", 1477*da0073e9SAndroid Build Coastguard Worker help="path to deprecated.yaml", 1478*da0073e9SAndroid Build Coastguard Worker ) 1479*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 1480*da0073e9SAndroid Build Coastguard Worker "--out", metavar="OUT", default=".", help="path to output directory" 1481*da0073e9SAndroid Build Coastguard Worker ) 1482*da0073e9SAndroid Build Coastguard Worker args = parser.parse_args() 1483*da0073e9SAndroid Build Coastguard Worker fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False) 1484*da0073e9SAndroid Build Coastguard Worker gen_pyi( 1485*da0073e9SAndroid Build Coastguard Worker args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm 1486*da0073e9SAndroid Build Coastguard Worker ) 1487*da0073e9SAndroid Build Coastguard Worker 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1490*da0073e9SAndroid Build Coastguard Worker main() 1491