xref: /aosp_15_r20/external/pytorch/torch/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2The torch package contains data structures for multi-dimensional
3tensors and defines mathematical operations over these tensors.
4Additionally, it provides many utilities for efficient serialization of
5Tensors and arbitrary types, and other useful utilities.
6
7It has a CUDA counterpart that enables you to run your tensor computations
8on an NVIDIA GPU with compute capability >= 3.0.
9"""
10
11# mypy: allow-untyped-defs
12
13import builtins
14import ctypes
15import glob
16import importlib
17import inspect
18import math
19import os
20import platform
21import sys
22import textwrap
23import threading
24from typing import (
25    Any as _Any,
26    Callable as _Callable,
27    Dict as _Dict,
28    Optional as _Optional,
29    overload as _overload,
30    Set as _Set,
31    Tuple as _Tuple,
32    Type as _Type,
33    TYPE_CHECKING,
34    TypeVar as _TypeVar,
35    Union as _Union,
36)
37from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
38
39
40if TYPE_CHECKING:
41    from .types import IntLikeType
42
43
44# multipy/deploy sets this import before importing torch; this is the most
45# reliable way we have to detect if we're running within deploy.
46# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137
47def _running_with_deploy() -> builtins.bool:
48    return sys.modules.get("torch._meta_registrations", None) is object
49
50
51from torch._utils import (
52    _functionalize_sync as _sync,
53    _import_dotted_name,
54    classproperty,
55)
56from torch._utils_internal import (
57    get_file_path,
58    prepare_multiprocessing_environment,
59    USE_GLOBAL_DEPS,
60    USE_RTLD_GLOBAL_WITH_LIBTORCH,
61)
62
63
64# TODO(torch_deploy) figure out how to freeze version.py in fbcode build
65if _running_with_deploy():
66    __version__ = "torch-deploy-1.8"
67else:
68    from torch.torch_version import __version__ as __version__
69
70__all__ = [
71    "BoolStorage",
72    "BoolTensor",
73    "ByteStorage",
74    "ByteTensor",
75    "CharStorage",
76    "CharTensor",
77    "DoubleStorage",
78    "DoubleTensor",
79    "FloatStorage",
80    "FloatTensor",
81    "GradScaler",
82    "IntStorage",
83    "IntTensor",
84    "LongStorage",
85    "LongTensor",
86    "ShortStorage",
87    "ShortTensor",
88    "SymBool",
89    "SymFloat",
90    "SymInt",
91    "Tensor",
92    "TypedStorage",
93    "UntypedStorage",
94    "are_deterministic_algorithms_enabled",
95    "autocast",
96    "chunk",
97    "compile",
98    "cond",
99    "enable_grad",
100    "export",
101    "get_default_device",
102    "get_deterministic_debug_mode",
103    "get_device_module",
104    "get_float32_matmul_precision",
105    "get_rng_state",
106    "inference_mode",
107    "initial_seed",
108    "is_deterministic_algorithms_warn_only_enabled",
109    "is_storage",
110    "is_tensor",
111    "is_warn_always_enabled",
112    "load",
113    "lobpcg",
114    "manual_seed",
115    "matmul",
116    "no_grad",
117    "rand",
118    "randn",
119    "save",
120    "seed",
121    "set_default_device",
122    "set_default_tensor_type",
123    "set_deterministic_debug_mode",
124    "set_float32_matmul_precision",
125    "set_printoptions",
126    "set_rng_state",
127    "set_warn_always",
128    "split",
129    "stack",
130    "sym_float",
131    "sym_int",
132    "sym_ite",
133    "sym_max",
134    "sym_min",
135    "sym_not",
136    "typename",
137    "unravel_index",
138    "use_deterministic_algorithms",
139    "vmap",
140]
141
142# Please keep this list sorted
143assert __all__ == sorted(__all__)
144
145################################################################################
146# Load the extension module
147################################################################################
148
149if sys.platform == "win32":
150
151    def _load_dll_libraries() -> None:
152        import sysconfig
153
154        from torch.version import cuda as cuda_version
155
156        pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files")
157        py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
158        th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
159        usebase_path = os.path.join(
160            sysconfig.get_config_var("userbase"), "Library", "bin"
161        )
162
163        # When users create a virtualenv that inherits the base environment,
164        # we need to add the corresponding library directory to the DLL
165        # search directories. Otherwise, DLL loading will rely on `PATH`, which
166        # depends on user settings.
167        if sys.exec_prefix != sys.base_exec_prefix:
168            base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
169        else:
170            base_py_dll_path = ""
171
172        dll_paths = [
173            p
174            for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path)
175            if os.path.exists(p)
176        ]
177
178        if not builtins.any(
179            os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths
180        ):
181            nvtoolsext_dll_path = os.path.join(
182                os.getenv(
183                    "NVTOOLSEXT_PATH",
184                    os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"),
185                ),
186                "bin",
187                "x64",
188            )
189        else:
190            nvtoolsext_dll_path = ""
191
192        if cuda_version and builtins.all(
193            not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths
194        ):
195            cuda_version_1 = cuda_version.replace(".", "_")
196            cuda_path_var = "CUDA_PATH_V" + cuda_version_1
197            default_path = os.path.join(
198                pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}"
199            )
200            cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin")
201        else:
202            cuda_path = ""
203
204        dll_paths.extend(
205            p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)
206        )
207
208        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
209        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
210        prev_error_mode = kernel32.SetErrorMode(0x0001)
211
212        kernel32.LoadLibraryW.restype = ctypes.c_void_p
213        if with_load_library_flags:
214            kernel32.LoadLibraryExW.restype = ctypes.c_void_p
215
216        for dll_path in dll_paths:
217            os.add_dll_directory(dll_path)
218
219        try:
220            ctypes.CDLL("vcruntime140.dll")
221            ctypes.CDLL("msvcp140.dll")
222            ctypes.CDLL("vcruntime140_1.dll")
223        except OSError:
224            print(
225                textwrap.dedent(
226                    """
227                    Microsoft Visual C++ Redistributable is not installed; this may lead to a DLL load failure.
228                    It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe
229                    """
230                ).strip()
231            )
232
233        dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
234        path_patched = False
235        for dll in dlls:
236            is_loaded = False
237            if with_load_library_flags:
238                res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
239                last_error = ctypes.get_last_error()
240                if res is None and last_error != 126:
241                    err = ctypes.WinError(last_error)
242                    err.strerror += (
243                        f' Error loading "{dll}" or one of its dependencies.'
244                    )
245                    raise err
246                elif res is not None:
247                    is_loaded = True
248            if not is_loaded:
249                if not path_patched:
250                    os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
251                    path_patched = True
252                res = kernel32.LoadLibraryW(dll)
253                if res is None:
254                    err = ctypes.WinError(ctypes.get_last_error())
255                    err.strerror += (
256                        f' Error loading "{dll}" or one of its dependencies.'
257                    )
258                    raise err
259
260        kernel32.SetErrorMode(prev_error_mode)
261
262    _load_dll_libraries()
263    del _load_dll_libraries
264
265
266def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
267    """Preloads cuda deps if they could not be found otherwise."""
268    # Should only be called on Linux if default path resolution have failed
269    assert platform.system() == "Linux", "Should only be called on Linux"
270
271    lib_path = None
272    for path in sys.path:
273        nvidia_path = os.path.join(path, "nvidia")
274        if not os.path.exists(nvidia_path):
275            continue
276        candidate_lib_paths = glob.glob(
277            os.path.join(nvidia_path, lib_folder, "lib", lib_name)
278        )
279        if candidate_lib_paths and not lib_path:
280            lib_path = candidate_lib_paths[0]
281        if lib_path:
282            break
283    if not lib_path:
284        raise ValueError(f"{lib_name} not found in the system path {sys.path}")
285    ctypes.CDLL(lib_path)
286
287
288# See Note [Global dependencies]
289def _load_global_deps() -> None:
290    if _running_with_deploy() or platform.system() == "Windows":
291        return
292
293    # Determine the file extension based on the platform
294    lib_ext = ".dylib" if platform.system() == "Darwin" else ".so"
295    lib_name = f"libtorch_global_deps{lib_ext}"
296    here = os.path.abspath(__file__)
297    global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name)
298
299    try:
300        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
301    except OSError as err:
302        # This can only happen for a wheel with CUDA libs as PyPI deps,
303        # as PyTorch is not purelib but nvidia-*-cu12 is
304        cuda_libs: _Dict[str, str] = {
305            "cublas": "libcublas.so.*[0-9]",
306            "cudnn": "libcudnn.so.*[0-9]",
307            "cuda_nvrtc": "libnvrtc.so.*[0-9]",
308            "cuda_runtime": "libcudart.so.*[0-9]",
309            "cuda_cupti": "libcupti.so.*[0-9]",
310            "cufft": "libcufft.so.*[0-9]",
311            "curand": "libcurand.so.*[0-9]",
312            "nvjitlink": "libnvJitLink.so.*[0-9]",
313            "cusparse": "libcusparse.so.*[0-9]",
314            "cusolver": "libcusolver.so.*[0-9]",
315            "nccl": "libnccl.so.*[0-9]",
316            "nvtx": "libnvToolsExt.so.*[0-9]",
317        }
318        is_cuda_lib_err = [
319            lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0]
320        ]
321        if not is_cuda_lib_err:
322            raise err
323        for lib_folder, lib_name in cuda_libs.items():
324            _preload_cuda_deps(lib_folder, lib_name)
325        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
326
327
328if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
329    _running_with_deploy() or platform.system() != "Windows"
330):
331    # Do it the hard way.  You might want to load libtorch with RTLD_GLOBAL in a
332    # few circumstances:
333    #
334    #   1. You're in a build environment (e.g., fbcode) where
335    #      libtorch_global_deps is not available, but you still need
336    #      to get mkl to link in with RTLD_GLOBAL or it will just
337    #      not work.
338    #
339    #   2. You're trying to run PyTorch under UBSAN and you need
340    #      to ensure that only one copy of libtorch is loaded, so
341    #      vptr checks work properly
342    #
343    # If you're using this setting, you must verify that all the libraries
344    # you load consistently use the same libstdc++, or you may have
345    # mysterious segfaults.
346    #
347    old_flags = sys.getdlopenflags()
348    sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)
349
350    from torch._C import *  # noqa: F403
351
352    sys.setdlopenflags(old_flags)
353    del old_flags
354
355else:
356    # Easy way.  You want this most of the time, because it will prevent
357    # C++ symbols from libtorch clobbering C++ symbols from other
358    # libraries, leading to mysterious segfaults.
359    #
360    # If building in an environment where libtorch_global_deps isn't available
361    # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
362    # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
363    #
364    # See Note [Global dependencies]
365    if USE_GLOBAL_DEPS:
366        _load_global_deps()
367    from torch._C import *  # noqa: F403
368
369
370class SymInt:
371    """
372    Like an int (including magic methods), but redirects all operations on the
373    wrapped node. This is used in particular to symbolically record operations
374    in the symbolic shape workflow.
375    """
376
377    def __init__(self, node):
378        # This field MUST be named node; C++ binding code assumes that this
379        # class has a field named node that stores SymNode
380        self.node = node
381
382    def __bool__(self):
383        return builtins.bool(self != 0)
384
385    def __int__(self):
386        return self.node.int_()
387
388    def __index__(self):
389        return self.node.int_()
390
391    # Magic methods installed by torch.fx.experimental.sym_node
392
393    def __round__(self, ndigits=None):
394        return self
395
396    def __truediv__(self, other):
397        if isinstance(other, (builtins.float, SymFloat)):
398            return sym_float(self).__float_truediv__(other)
399        if not isinstance(other, (builtins.int, SymInt)):
400            return NotImplemented
401        return self.__int_truediv__(other)
402
403    def __rtruediv__(self, other):
404        if isinstance(other, (builtins.float, SymFloat)):
405            return sym_float(self).__rfloat_truediv__(other)
406        if not isinstance(other, (builtins.int, SymInt)):
407            return NotImplemented
408        return self.__rint_truediv__(other)
409
410    def __floordiv__(self, other):
411        if isinstance(other, (builtins.float, SymFloat)):
412            return sym_float(math.floor(sym_float(self) / other))
413        if not isinstance(other, (builtins.int, SymInt)):
414            return NotImplemented
415        return self.__int_floordiv__(other)
416
417    def __rfloordiv__(self, other):
418        if isinstance(other, (builtins.float, SymFloat)):
419            return sym_float(math.floor(other / sym_float(self)))
420        if not isinstance(other, (builtins.int, SymInt)):
421            return NotImplemented
422        return self.__rint_floordiv__(other)
423
424    # nb: complex is impossible to handle correctly here: with a negative base
425    # and an integral float exponent we would need to diverge semantics and
426    # just always return complex.  Pretend this problem
427    # doesn't exist.
428    def __pow__(self, other):
429        if isinstance(other, (builtins.float, SymFloat)):
430            return sym_float(self).__pow__(other)
431        if not isinstance(other, (builtins.int, SymInt)):
432            return NotImplemented
433        # Guards!  This guard is necessary because we need to know it to
434        # determine the output type of this operation
435        if other >= 0:
436            return self.__pow_by_natural__(other)
437        else:
438            # Mercifully, when the exponent is negative, Python just promotes
439            # to doubles and does a float pow:
440            #
441            #   if (Py_SIZE(b) < 0 && c == NULL) {
442            #       /* if exponent is negative and there's no modulus:
443            #              return a float.  This works because we know
444            #              that this calls float_pow() which converts its
445            #              arguments to double. */
446            #       Py_DECREF(a);
447            #       Py_DECREF(b);
448            #       return PyFloat_Type.tp_as_number->nb_power(v, w, x);
449            #   }
450            return sym_float(self).__pow__(sym_float(other))
451
452    def __rpow__(self, other):
453        if isinstance(other, (builtins.float, SymFloat)):
454            return sym_float(self).__rpow__(other)
455        if not isinstance(other, (builtins.int, SymInt)):
456            return NotImplemented
457        if self >= 0:  # self is exponent
458            return self.__rpow_by_natural__(other)
459        else:
460            return sym_float(self).__rpow__(sym_float(other))
461
462    def __eq__(self, other: object) -> builtins.bool:
463        raise TypeError("type stub not overridden")
464
465    def __lt__(self, other) -> builtins.bool:
466        raise TypeError("type stub not overridden")
467
468    def __gt__(self, other) -> builtins.bool:
469        raise TypeError("type stub not overridden")
470
471    def __le__(self, other) -> builtins.bool:
472        raise TypeError("type stub not overridden")
473
474    def __ge__(self, other) -> builtins.bool:
475        raise TypeError("type stub not overridden")
476
477    def __add__(self, other) -> "SymInt":
478        raise TypeError("type stub not overridden")
479
480    def __mod__(self, other: "IntLikeType") -> "SymInt":
481        raise TypeError("type stub not overridden")
482
483    def __mul__(self, other) -> "SymInt":
484        raise TypeError("type stub not overridden")
485
486    def __pow_by_natural__(self, other) -> "SymInt":
487        raise TypeError("type stub not overridden")
488
489    def __rpow_by_natural__(self, other) -> "SymInt":
490        raise TypeError("type stub not overridden")
491
492    def __int_truediv__(self, other) -> "SymFloat":
493        raise TypeError("type stub not overridden")
494
495    def __rint_truediv__(self, other) -> "SymFloat":
496        raise TypeError("type stub not overridden")
497
498    def __int_floordiv__(self, other) -> "SymFloat":
499        raise TypeError("type stub not overridden")
500
501    def __rint_floordiv__(self, other) -> "SymFloat":
502        raise TypeError("type stub not overridden")
503
504    def __sym_max__(self, other):
505        raise TypeError("type stub not overridden")
506
507    def __sym_min__(self, other):
508        raise TypeError("type stub not overridden")
509
510    def __sym_float__(self):
511        raise TypeError("type stub not overridden")
512
513    def __neg__(self):
514        raise TypeError("type stub not overridden")
515
516    def __sub__(self, other: "IntLikeType") -> "SymInt":
517        raise TypeError("type stub not overridden")
518
519    def __repr__(self):
520        return self.node._graph_repr()
521
522    def _sympy_(self):
523        return self.node.expr
524
525    def __hash__(self) -> builtins.int:
526        if self.node.is_nested_int():
527            return hash(self.node.nested_int())
528        else:
529            # We could support constant SymInts as well, but not doing it for now
530            raise TypeError("unhashable type: non-nested SymInt")
531            # TODO: Force specialization
532            # This can't be done because the TypeError here is load bearing
533            # for einops
534            # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
535            # return hash(builtins.int(self))
536
537    def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
538        """Represent this int as an exact integer ratio"""
539        return self, 1
540
541    def bit_length(self) -> builtins.int:
542        # TODO: A more relaxed guard is possible here, where you guard to
543        # allow all integer quantities which would result in the same bit
544        # length.  We can also just make a dedicated Sympy function for
545        # computing this quantity and represent it symbolically.
546        return builtins.int(self).bit_length()
547
548    def conjugate(self) -> "SymInt":
549        return self
550
551
552class SymFloat:
553    """
554    Like a float (including magic methods), but redirects all operations on the
555    wrapped node. This is used in particular to symbolically record operations
556    in the symbolic shape workflow.
557    """
558
559    def __init__(self, node):
560        # This field MUST be named node; C++ binding code assumes that this
561        # class has a field named node that stores SymNode
562        self.node = node
563
564    def __truediv__(self, other):
565        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
566            return NotImplemented
567        return self.__float_truediv__(sym_float(other))
568
569    def __rtruediv__(self, other):
570        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
571            return NotImplemented
572        return self.__rfloat_truediv__(sym_float(other))
573
574    def __floordiv__(self, other):
575        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
576            return NotImplemented
577        return sym_float(math.floor(self / sym_float(other)))
578
579    def __rfloordiv__(self, other):
580        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
581            return NotImplemented
582        return sym_float(math.floor(sym_float(other) / self))
583
584    def __bool__(self):
585        return self.node.bool_()
586
587    def __float__(self):
588        return self.node.guard_float("", 0)
589
590    # Symbolic power does NOT work with a negative base; this is to avoid
591    # potential complex outputs
592    def __pow__(self, other):
593        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
594            return NotImplemented
595        torch._check(self >= 0)
596        return self.__float_pow__(other)
597
598    def __rpow__(self, other):
599        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
600            return NotImplemented
601        torch._check(other >= 0)
602        return self.__rfloat_pow__(other)
603
604    # Magic methods installed by torch.fx.experimental.sym_node
605
606    def __eq__(self, other: object) -> builtins.bool:
607        raise TypeError("type stub not overridden")
608
609    def __lt__(self, other) -> builtins.bool:
610        raise TypeError("type stub not overridden")
611
612    def __gt__(self, other) -> builtins.bool:
613        raise TypeError("type stub not overridden")
614
615    def __le__(self, other) -> builtins.bool:
616        raise TypeError("type stub not overridden")
617
618    def __ge__(self, other) -> builtins.bool:
619        raise TypeError("type stub not overridden")
620
621    def __float_pow__(self, other) -> "SymFloat":
622        raise TypeError("type stub not overridden")
623
624    def __rfloat_pow__(self, other) -> "SymFloat":
625        raise TypeError("type stub not overridden")
626
627    def __float_truediv__(self, other) -> "SymFloat":
628        raise TypeError("type stub not overridden")
629
630    def __rfloat_truediv__(self, other) -> "SymFloat":
631        raise TypeError("type stub not overridden")
632
633    def __trunc__(self):
634        raise TypeError("type stub not overridden")
635
636    def __sym_max__(self, other):
637        raise TypeError("type stub not overridden")
638
639    def __sym_min__(self, other):
640        raise TypeError("type stub not overridden")
641
642    def __sym_int__(self):
643        raise TypeError("type stub not overridden")
644
645    def is_integer(self):
646        """Return True if the float is an integer."""
647        raise TypeError("type stub not overridden")
648
649    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
650        """Represent this float as an exact integer ratio"""
651        return builtins.float(self).as_integer_ratio()
652
653    def __repr__(self):
654        return self.node._graph_repr()
655
656    def _sympy_(self):
657        return self.node.expr
658
659    def __hash__(self):
660        return hash(builtins.float(self))
661
662
663class SymBool:
664    """
665    Like a bool (including magic methods), but redirects all operations on the
666    wrapped node. This is used in particular to symbolically record operations
667    in the symbolic shape workflow.
668
669    Unlike regular bools, regular boolean operators will force extra guards instead
670    of evaluating symbolically.  Use the bitwise operators instead to handle this.
671    """
672
673    def __init__(self, node):
674        # This field MUST be named node; C++ binding code assumes that this
675        # class has a field named node that stores SymNode
676        self.node = node
677
678    def __bool__(self):
679        return self.node.bool_()
680
681    def __int__(self):
682        return builtins.int(self.node.bool_())
683
684    # Magic methods installed by torch.fx.experimental.sym_node
685    def __and__(self, other) -> "SymBool":
686        raise TypeError("type stub not overridden")
687
688    def __or__(self, other) -> "SymBool":
689        raise TypeError("type stub not overridden")
690
691    # We very carefully define __sym_not__, and not a number of other
692    # plausible alternatives:
693    #
694    #   - We do not override __not__ because this is not a real magic
695    #     method; you cannot override the meaning of the not builtin in
696    #     Python.  We use the name 'sym_not' to clarify that in user code you
697    #     cannot use the builtin not or operator.not_ or operator.__not__ and
698    #     hit this magic method; you must use our custom sym_not operator.
699    #
700    #   - We do not override the __invert__ method because SymBool is
701    #     meant to be usable in situations where bool is expected.  However,
702    #     bitwise negation ~a does the wrong thing with booleans (because
703    #     bool is a subclass of int, so ~1 = -2 which is not falseish.)
704    #     This would be a giant footgun, so we get around it by defining
705    #     our own operator.  Note that bitwise and/or do the right thing,
706    #     so we reuse the conventional operators there for readability.
707    #
708    def __sym_not__(self) -> "SymBool":
709        raise TypeError("type stub not overridden")
710
711    def __sym_ite__(self, then_val, else_val):
712        raise TypeError("type stub not overridden")
713
714    def __eq__(self, other) -> builtins.bool:
715        raise TypeError("type stub not overridden")
716
717    def __repr__(self):
718        return self.node._graph_repr()
719
720    def _sympy_(self):
721        return self.node.expr
722
723    def __hash__(self):
724        if self.node.is_constant():
725            return hash(self.node.bool_())
726        else:
727            # Force specialization
728            return hash(builtins.bool(self))
729
730
731def sym_not(a):
732    r"""SymInt-aware utility for logical negation.
733
734    Args:
735        a (SymBool or bool): Object to negate
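
    Example (a minimal sketch; a plain :class:`bool` is shown, while a
    :class:`SymBool` input dispatches to its ``__sym_not__`` hook instead)::

        >>> torch.sym_not(True)
        False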
736    """
737    import sympy
738
739    if overrides.has_torch_function_unary(a):
740        return overrides.handle_torch_function(sym_not, (a,), a)
741    if hasattr(a, "__sym_not__"):
742        return a.__sym_not__()
743    if isinstance(a, sympy.Basic):
744        return ~a  # type: ignore[operator]
745    return not a
746
747
748def sym_float(a):
749    r"""SymInt-aware utility for float casting.
750
751    Args:
752        a (SymInt, SymFloat, or object): Object to cast
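
    Example (a minimal sketch; plain numbers fall through to
    :func:`builtins.float`, while symbolic inputs use ``__sym_float__``)::

        >>> torch.sym_float(2)
        2.0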
753    """
754    if overrides.has_torch_function_unary(a):
755        return overrides.handle_torch_function(sym_float, (a,), a)
756    if isinstance(a, SymFloat):
757        return a
758    elif hasattr(a, "__sym_float__"):
759        return a.__sym_float__()
760    return builtins.float(a)  # type: ignore[operator]
761
762
763def sym_int(a):
764    r"""SymInt-aware utility for int casting.
765
766    Args:
767        a (SymInt, SymFloat, or object): Object to cast
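
    Example (a minimal sketch; plain numbers fall through to
    :func:`builtins.int`, while a :class:`SymFloat` is truncated toward zero)::

        >>> torch.sym_int(2.7)
        2
        >>> torch.sym_int(-2.7)
        -2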
768    """
769    if overrides.has_torch_function_unary(a):
770        return overrides.handle_torch_function(sym_int, (a,), a)
771    if isinstance(a, SymInt):
772        return a
773    elif isinstance(a, SymFloat):
774        return math.trunc(a)
775    return builtins.int(a)  # type: ignore[operator]
776
777
778def sym_max(a, b):
779    """
780    SymInt-aware utility for max which avoids branching on a < b.
781    Unlike builtins.max(), this only works for int/float, and it always
782    promotes to float if any argument is float (unlike builtins.max, which
783    will faithfully preserve the type of the input argument).
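
    Example (a minimal sketch of the plain-number fallback; symbolic inputs
    dispatch to ``__sym_max__`` instead)::

        >>> torch.sym_max(2, 3)
        3
        >>> torch.sym_max(2, 3.0)
        3.0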
784    """
785    if overrides.has_torch_function((a, b)):
786        return overrides.handle_torch_function(sym_max, (a, b), a, b)
787    if isinstance(a, (SymInt, SymFloat)):
788        return a.__sym_max__(b)
789    elif isinstance(b, (SymInt, SymFloat)):
790        # Due to promotion semantics, this operator is commutative:
791        # max(1, 1.0) === max(1.0, 1) === 1.0
792        return b.__sym_max__(a)
793    # TODO: Probably can make bool work too, just lazy
794
795    all_types, float_types = __all_and_float_types()
796
797    assert isinstance(a, all_types), type(a)
798    assert isinstance(b, all_types), type(b)
799    if isinstance(a, float_types) or isinstance(b, float_types):
800        return builtins.float(builtins.max(a, b))
801    else:
802        return builtins.max(a, b)
803
804
805def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
806    try:
807        import numpy as np
808
809        all_types: _Tuple[_Type, ...] = (
810            np.integer,
811            np.floating,
812            builtins.int,
813            builtins.float,
814        )
815        float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
816    except ModuleNotFoundError:
817        all_types = (builtins.int, builtins.float)
818        float_types = (builtins.float,)
819
820    return all_types, float_types
821
822
823def sym_min(a, b):
824    """SymInt-aware utility for min()."""
825    if overrides.has_torch_function((a, b)):
826        return overrides.handle_torch_function(sym_min, (a, b), a, b)
827    if isinstance(a, (SymInt, SymFloat)):
828        return a.__sym_min__(b)
829    elif isinstance(b, (SymInt, SymFloat)):
830        return b.__sym_min__(a)
831
832    all_types, float_types = __all_and_float_types()
833
834    assert isinstance(a, all_types), type(a)
835    assert isinstance(b, all_types), type(b)
836    if isinstance(a, float_types) or isinstance(b, float_types):
837        return builtins.float(builtins.min(a, b))
838    else:
839        return builtins.min(a, b)
840
841
842# Drop-in replacement for math.sqrt, math.sin, math.cos, etc.
843def _get_sym_math_fn(name):
844    def fn(a):
845        if overrides.has_torch_function_unary(a):
846            return overrides.handle_torch_function(fn, (a,), a)
847        if hasattr(a, f"__sym_{name}__"):
848            return getattr(a, f"__sym_{name}__")()
849        return getattr(math, name)(a)
850
851    return fn
852
853
854__fn, __name, __sym_name = None, "", ""
855for __name in (
856    "sqrt",
857    "cos",
858    "cosh",
859    "sin",
860    "sinh",
861    "tan",
862    "tanh",
863    "asin",
864    "acos",
865    "atan",
866):
867    __sym_name = f"_sym_{__name}"
868    __fn = _get_sym_math_fn(__name)
869    __fn.__qualname__ = __fn.__name__ = __sym_name
870    globals()[__sym_name] = __fn
871
872del __fn, __name, __sym_name, _get_sym_math_fn
873
874# Adding temporary shortcut
875sym_sqrt = globals()["_sym_sqrt"]
876__all__.append("sym_sqrt")
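
# A minimal illustrative sketch of the generated helpers above: for plain
# Python numbers they simply fall through to the corresponding ``math``
# function, e.g.
#
#   >>> torch._sym_sqrt(4.0)
#   2.0
#   >>> torch._sym_cos(0.0)
#   1.0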
877
878
879def sym_ite(b, t, f):
880    if overrides.has_torch_function((b, t, f)):
881        return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
882    assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
883    if isinstance(b, SymBool):
884        return b.__sym_ite__(t, f)
885    return t if b else f
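
# A minimal illustrative sketch of sym_ite's plain-bool fallback (a SymBool
# predicate dispatches to ``__sym_ite__`` instead):
#
#   >>> torch.sym_ite(True, 1, 2)
#   1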
886
887
888# Check to see if we can load C extensions, and if not provide some guidance
889# on what the problem might be.
890try:
891    # _initExtension is chosen (arbitrarily) as a sentinel.
892    from torch._C import _initExtension
893except ImportError:
894    import torch._C as _C_for_compiled_check
895
896    # The __file__ check only works for Python 3.7 and above.
897    if _C_for_compiled_check.__file__ is None:
898        raise ImportError(
899            textwrap.dedent(
900                """
901                Failed to load PyTorch C extensions:
902                    It appears that PyTorch has loaded the `torch/_C` folder
903                    of the PyTorch repository rather than the C extensions which
904                    are expected in the `torch._C` namespace. This can occur when
905                    using the `install` workflow. e.g.
906                        $ python setup.py install && python -c "import torch"
907
908                    This error can generally be solved using the `develop` workflow
909                        $ python setup.py develop && python -c "import torch"  # This should succeed
910                    or by running Python from a different directory.
911                """
912            ).strip()
913        ) from None
914    raise  # If __file__ is not None the cause is unknown, so just re-raise.
915
916# The torch._C submodule is already loaded via `from torch._C import *` above
917# Make an explicit reference to the _C submodule to appease linters
918from torch import _C as _C
919
920
921__name, __obj = "", None
922for __name in dir(_C):
923    if __name[0] != "_" and not __name.endswith("Base"):
924        __all__.append(__name)
925        __obj = getattr(_C, __name)
926        if callable(__obj) or inspect.isclass(__obj):
927            if __obj.__module__ != __name__:  # "torch"
928                # TODO: fix their module from C++ side
929                if __name not in {
930                    "DisableTorchFunctionSubclass",
931                    "DisableTorchFunction",
932                    "Generator",
933                }:
934                    __obj.__module__ = __name__  # "torch"
935    elif __name == "TensorBase":
936        # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
937        delattr(sys.modules[__name__], __name)
938
939del __name, __obj
940
941if not TYPE_CHECKING:
942    # issue 38137 and python issue 43367. Submodules of a C extension are
943    # non-standard, and attributes of those submodules cannot be pickled since
944    # pickle expects to be able to import them as "from _C.sub import attr",
945    # which fails with "_C is not a package".
946    def _import_extension_to_sys_modules(module, memo=None):
947        if memo is None:
948            memo = set()
949        if module in memo:
950            return
951        memo.add(module)
952        module_name = module.__name__
953        for name in dir(module):
954            member = getattr(module, name)
955            member_name = getattr(member, "__name__", "")
956            if inspect.ismodule(member) and member_name.startswith(module_name):
957                sys.modules.setdefault(member_name, member)
958                # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
959                _import_extension_to_sys_modules(member, memo)
960
961    _import_extension_to_sys_modules(_C)
962    del _import_extension_to_sys_modules
963
964################################################################################
965# Define basic utilities
966################################################################################
967
968
969def typename(obj: _Any, /) -> str:
970    """
971    String representation of the type of an object.
972
973    This function returns a fully qualified string representation of an object's type.
974    Args:
975        obj (object): The object whose type to represent
976    Returns:
977        str: the type of the object ``obj``
978    Example:
979        >>> x = torch.tensor([1, 2, 3])
980        >>> torch.typename(x)
981        'torch.LongTensor'
982        >>> torch.typename(torch.nn.Parameter)
983        'torch.nn.parameter.Parameter'
984    """
985    if isinstance(obj, torch.Tensor):
986        return obj.type()
987
988    module = getattr(obj, "__module__", "") or ""
989    qualname = ""
990
991    if hasattr(obj, "__qualname__"):
992        qualname = obj.__qualname__
993    elif hasattr(obj, "__name__"):
994        qualname = obj.__name__
995    else:
996        module = obj.__class__.__module__ or ""
997        qualname = obj.__class__.__qualname__
998
999    if module in {"", "builtins"}:
1000        return qualname
1001    return f"{module}.{qualname}"
1002
1003
1004def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
1005    r"""Returns True if `obj` is a PyTorch tensor.
1006
1007    Note that this function is simply doing ``isinstance(obj, Tensor)``.
1008    Using that ``isinstance`` check is better for typechecking with mypy,
1009    and more explicit - so it's recommended to use that instead of
1010    ``is_tensor``.
1011
1012    Args:
1013        obj (object): Object to test
1014    Example::
1015
1016        >>> x = torch.tensor([1, 2, 3])
1017        >>> torch.is_tensor(x)
1018        True
1019
1020    """
1021    return isinstance(obj, torch.Tensor)
1022
1023
1024def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
1025    r"""Returns True if `obj` is a PyTorch storage object.
1026
1027    Args:
1028        obj (Object): Object to test
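
    Example (a minimal sketch; ``Tensor.untyped_storage()`` is used here just
    to obtain a storage object)::

        >>> x = torch.tensor([1, 2, 3])
        >>> torch.is_storage(x)
        False
        >>> torch.is_storage(x.untyped_storage())
        True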
1029    """
1030    return type(obj) in _storage_classes
1031
1032
1033_GLOBAL_DEVICE_CONTEXT = threading.local()
1034
1035
1036def get_default_device() -> "torch.device":
1037    r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
1038    global _GLOBAL_DEVICE_CONTEXT
1039
1040    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
1041        device = _GLOBAL_DEVICE_CONTEXT.device_context.device
1042        if device.index is not None:
1043            return device
1044        else:
1045            # TODO: Call something like a get_device_index() method
1046            # corresponding to each device type
1047            return torch.tensor([]).device
1048    else:
1049        return torch.device("cpu")
1050
1051
1052def set_default_device(
1053    device: _Optional[_Union["torch.device", str, builtins.int]],
1054) -> None:
1055    """Sets the default ``torch.Tensor`` to be allocated on ``device``.  This
1056    does not affect factory function calls which are called with an explicit
1057    ``device`` argument.  Factory calls will be performed as if they
1058    were passed ``device`` as an argument.
1059
1060    To only temporarily change the default device instead of setting it
1061    globally, use ``with torch.device(device):`` instead.
1062
1063    The default device is initially ``cpu``.  If you set the default tensor
1064    device to another device (e.g., ``cuda``) without a device index, tensors
1065    will be allocated on whatever the current device for the device type is,
1066    even after :func:`torch.cuda.set_device` is called.
1067
1068    .. warning::
1069
1070        This function imposes a slight performance cost on every Python
1071        call to the torch API (not just factory functions).  If this
1072        is causing problems for you, please comment on
1073        https://github.com/pytorch/pytorch/issues/92701
1074
1075    .. note::
1076
1077        This doesn't affect functions that create tensors that share the same memory as the input, like:
1078        :func:`torch.from_numpy` and :func:`torch.frombuffer`
1079
1080    Args:
1081        device (device or string): the device to set as default
1082
1083    Example::
1084
1085        >>> # xdoctest: +SKIP("requires cuda, changes global state")
1086        >>> torch.get_default_device()
1087        device(type='cpu')
1088        >>> torch.set_default_device('cuda')  # current device is 0
1089        >>> torch.get_default_device()
1090        device(type='cuda', index=0)
1091        >>> torch.set_default_device('cuda')
1092        >>> torch.cuda.set_device('cuda:1')  # current device is 1
1093        >>> torch.get_default_device()
1094        device(type='cuda', index=1)
1095        >>> torch.set_default_device('cuda:1')
1096        >>> torch.get_default_device()
1097        device(type='cuda', index=1)
1098
1099    """
1100    global _GLOBAL_DEVICE_CONTEXT
1101    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
1102        device_context = _GLOBAL_DEVICE_CONTEXT.device_context
1103        if device_context is not None:
1104            device_context.__exit__(None, None, None)
1105
1106    if device is None:
1107        device_context = None
1108    else:
1109        from torch.utils._device import DeviceContext
1110
1111        device_context = DeviceContext(device)
1112        device_context.__enter__()
1113    _GLOBAL_DEVICE_CONTEXT.device_context = device_context
1114
1115
1116def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
1117    r"""
1118    .. warning::
1119
1120        This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and
1121        :func:`torch.set_default_device()` as alternatives.
1122
1123    Sets the default ``torch.Tensor`` type to the floating point tensor type
1124    ``t``. This type will also be used as the default floating point type for
1125    type inference in :func:`torch.tensor`.
1126
1127    The default floating point tensor type is initially ``torch.FloatTensor``.
1128
1129    Args:
1130        t (type or string): the floating point tensor type or its name
1131
1132    Example::
1133
1134        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
1135        >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
1136        torch.float32
1137        >>> torch.set_default_tensor_type(torch.DoubleTensor)
1138        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
1139        torch.float64
1140
1141    """
1142    if isinstance(t, str):
1143        t = _import_dotted_name(t)
1144    _C._set_default_tensor_type(t)
1145
1146
1147def set_default_dtype(d: "torch.dtype", /) -> None:
1148    r"""
1149
1150    Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
1151    as inputs. Other dtypes will cause torch to raise an exception.
1152
1153    When PyTorch is initialized its default floating point dtype is torch.float32,
1154    and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
1155    type inference. The default floating point dtype is used to:
1156
1157    1. Implicitly determine the default complex dtype. When the default floating type is float16,
1158       the default complex dtype is complex32. For float32, the default complex dtype is complex64.
1159       For float64, it is complex128. For bfloat16, an exception will be raised because
1160       there is no corresponding complex type for bfloat16.
1161    2. Infer the dtype for tensors constructed using Python floats or complex Python
1162       numbers. See examples below.
1163    3. Determine the result of type promotion between bool and integer tensors and
1164       Python floats and complex Python numbers.
1165
1166    Args:
1167        d (:class:`torch.dtype`): the floating point dtype to make the default.
1168
1169    Example:
1170        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
1171        >>> # initial default for floating point is torch.float32
1172        >>> # Python floats are interpreted as float32
1173        >>> torch.tensor([1.2, 3]).dtype
1174        torch.float32
1175        >>> # initial default for complex is torch.complex64
1176        >>> # Complex Python numbers are interpreted as complex64
1177        >>> torch.tensor([1.2, 3j]).dtype
1178        torch.complex64
1179
1180        >>> torch.set_default_dtype(torch.float64)
1181        >>> # Python floats are now interpreted as float64
1182        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
1183        torch.float64
1184        >>> # Complex Python numbers are now interpreted as complex128
1185        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
1186        torch.complex128
1187
1188        >>> torch.set_default_dtype(torch.float16)
1189        >>> # Python floats are now interpreted as float16
1190        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
1191        torch.float16
1192        >>> # Complex Python numbers are now interpreted as complex32
1193        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
1194        torch.complex32
1195
1196    """
1197    _C._set_default_dtype(d)
1198
1199
1200def use_deterministic_algorithms(
1201    mode: builtins.bool,
1202    *,
1203    warn_only: builtins.bool = False,
1204) -> None:
1205    r"""Sets whether PyTorch operations must use "deterministic"
1206    algorithms. That is, algorithms which, given the same input, and when
1207    run on the same software and hardware, always produce the same output.
1208    When enabled, operations will use deterministic algorithms when available,
1209    and if only nondeterministic algorithms are available they will throw a
1210    :class:`RuntimeError` when called.
1211
1212    .. note:: This setting alone is not always enough to make an application
1213        reproducible. Refer to :ref:`reproducibility` for more information.
1214
1215    .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
1216        interface for this feature.
1217
1218    The following normally-nondeterministic operations will act
1219    deterministically when ``mode=True``:
1220
1221        * :class:`torch.nn.Conv1d` when called on CUDA tensor
1222        * :class:`torch.nn.Conv2d` when called on CUDA tensor
1223        * :class:`torch.nn.Conv3d` when called on CUDA tensor
1224        * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
1225        * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
1226        * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
1227        * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
1228        * :func:`torch.bmm` when called on sparse-dense CUDA tensors
1229        * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
1230          and the index is a list of tensors
1231        * :func:`torch.Tensor.index_put` with ``accumulate=False``
1232        * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
1233          tensor
1234        * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
1235          tensor
1236        * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
1237        * :func:`torch.gather` when called on a CUDA tensor that requires grad
1238        * :func:`torch.index_add` when called on CUDA tensor
1239        * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
1240        * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
1241        * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
1242        * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
1243        * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
1244
1245    The following normally-nondeterministic operations will throw a
1246    :class:`RuntimeError` when ``mode=True``:
1247
1248        * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
1249        * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
1250        * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
1251        * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
1252        * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
1253        * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
1254        * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
1255        * :class:`torch.nn.MaxUnpool1d`
1256        * :class:`torch.nn.MaxUnpool2d`
1257        * :class:`torch.nn.MaxUnpool3d`
1258        * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
1259          and one of the following modes is used:
1260
1261          - ``linear``
1262          - ``bilinear``
1263          - ``bicubic``
1264          - ``trilinear``
1265
1266        * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
1267        * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
1268        * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
1269        * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
1270        * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
1271        * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
1272        * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
1273        * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
1274          ``mode='max'``
1275        * :func:`torch.Tensor.put_` when ``accumulate=False``
1276        * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
1277        * :func:`torch.histc` when called on a CUDA tensor
1278        * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
1279          tensor is given
1280        * :func:`torch.kthvalue` when called on a CUDA tensor
1281        * :func:`torch.median` with indices output when called on a CUDA tensor
1282        * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
1283        * :func:`torch.cumsum` when called on a CUDA tensor whose dtype is floating point or complex
1284        * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
1285        * :func:`torch.Tensor.resize_` when called with a quantized tensor
1286
1287    In addition, several operations fill uninitialized memory when this setting
1288    is turned on and when
1289    :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
1290    See the documentation for that attribute for more information.
1291
1292    A handful of CUDA operations are nondeterministic if the CUDA version is
1293    10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
1294    or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
1295    details: `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_
1296    If neither of these environment variable configurations is set, a :class:`RuntimeError`
1297    will be raised from these operations when called with CUDA tensors:
1298
1299        * :func:`torch.mm`
1300        * :func:`torch.mv`
1301        * :func:`torch.bmm`
1302
1303    Note that deterministic operations tend to have worse performance than
1304    nondeterministic operations.
1305
1306    .. note::
1307
1308        This flag does not detect or prevent nondeterministic behavior caused
1309        by calling an inplace operation on a tensor with an internal memory
1310        overlap or by giving such a tensor as the :attr:`out` argument for an
1311        operation. In these cases, multiple writes of different data may target
1312        a single memory location, and the order of writes is not guaranteed.
1313
1314    Args:
1315        mode (:class:`bool`): If True, makes potentially nondeterministic
1316            operations switch to a deterministic algorithm or throw a runtime
1317            error. If False, allows nondeterministic operations.
1318
1319    Keyword args:
1320        warn_only (:class:`bool`, optional): If True, operations that do not
1321            have a deterministic implementation will throw a warning instead of
1322            an error. Default: ``False``
1323
1324    Example::
1325
1326        >>> # xdoctest: +SKIP
1327        >>> torch.use_deterministic_algorithms(True)
1328
1329        # Forward mode nondeterministic error
1330        >>> torch.randn(10, device='cuda').kthvalue(1)
1331        ...
1332        RuntimeError: kthvalue CUDA does not have a deterministic implementation...
1333
1334        # Backward mode nondeterministic error
1335        >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
1336        ...
1337        RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
1338    """
1339    _C._set_deterministic_algorithms(mode, warn_only=warn_only)
1340
1341
1342def are_deterministic_algorithms_enabled() -> builtins.bool:
1343    r"""Returns True if the global deterministic flag is turned on. Refer to
1344    :func:`torch.use_deterministic_algorithms` documentation for more details.
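
    Example (illustrative; the flag is process-global state)::

        >>> # xdoctest: +SKIP("changes global state")
        >>> torch.use_deterministic_algorithms(True)
        >>> torch.are_deterministic_algorithms_enabled()
        True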
1345    """
1346    return _C._get_deterministic_algorithms()
1347
1348
1349def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
1350    r"""Returns True if the global deterministic flag is set to warn only.
1351    Refer to :func:`torch.use_deterministic_algorithms` documentation for more
1352    details.
1353    """
1354    return _C._get_deterministic_algorithms_warn_only()
1355
1356
1357def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
1358    r"""Sets the debug mode for deterministic operations.
1359
1360    .. note:: This is an alternative interface for
1361        :func:`torch.use_deterministic_algorithms`. Refer to that function's
1362        documentation for details about affected operations.
1363
1364    Args:
1365        debug_mode (str or int): If "default" or 0, don't error or warn on
1366            nondeterministic operations. If "warn" or 1, warn on
1367            nondeterministic operations. If "error" or 2, error on
1368            nondeterministic operations.
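
    Example (illustrative; changes process-global state)::

        >>> # xdoctest: +SKIP("changes global state")
        >>> torch.set_deterministic_debug_mode("warn")
        >>> torch.get_deterministic_debug_mode()
        1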
1369    """
1370
1371    # NOTE: builtins.int is used here because int in this scope resolves
1372    # to torch.int
1373    if not isinstance(debug_mode, (builtins.int, str)):
1374        raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")
1375
1376    if isinstance(debug_mode, str):
1377        if debug_mode == "default":
1378            debug_mode = 0
1379        elif debug_mode == "warn":
1380            debug_mode = 1
1381        elif debug_mode == "error":
1382            debug_mode = 2
1383        else:
1384            raise RuntimeError(
1385                "invalid value of debug_mode, expected one of `default`, "
1386                f"`warn`, `error`, but got {debug_mode}"
1387            )
1388
1389    if debug_mode == 0:
1390        _C._set_deterministic_algorithms(False)
1391    elif debug_mode == 1:
1392        _C._set_deterministic_algorithms(True, warn_only=True)
1393    elif debug_mode == 2:
1394        _C._set_deterministic_algorithms(True)
1395    else:
1396        raise RuntimeError(
1397            "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}"
1398        )
1399
1400
1401def get_deterministic_debug_mode() -> builtins.int:
1402    r"""Returns the current value of the debug mode for deterministic
1403    operations. Refer to :func:`torch.set_deterministic_debug_mode`
1404    documentation for more details.
1405    """
1406
1407    if _C._get_deterministic_algorithms():
1408        if _C._get_deterministic_algorithms_warn_only():
1409            return 1
1410        else:
1411            return 2
1412    else:
1413        return 0
1414
1415
1416def get_float32_matmul_precision() -> str:
1417    r"""Returns the current value of float32 matrix multiplication precision. Refer to
1418    :func:`torch.set_float32_matmul_precision` documentation for more details.
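
    Example (illustrative; "highest" is the initial default, but other code
    may have changed it)::

        >>> # xdoctest: +SKIP("depends on global state")
        >>> torch.get_float32_matmul_precision()
        'highest'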
1419    """
1420    return _C._get_float32_matmul_precision()
1421
1422
1423def set_float32_matmul_precision(precision: str) -> None:
1424    r"""Sets the internal precision of float32 matrix multiplications.
1425
1426    Running float32 matrix multiplications in lower precision may significantly increase
1427    performance, and in some programs the loss of precision has a negligible impact.
1428
1429    Supports three settings:
1430
1431        * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
1432          bits with 23 bits explicitly stored) for internal computations.
1433        * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
1434          mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
1435          (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
1436          algorithms are available.  Otherwise float32 matrix multiplications are computed
1437          as if the precision is "highest".  See below for more information on the bfloat16
1438          approach.
1439        * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
1440          bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
1441          using that datatype internally is available. Otherwise float32
1442          matrix multiplications are computed as if the precision is "high".
1443
1444    When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
1445    that is more complicated than simply truncating to some smaller number of mantissa bits
1446    (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored).  Refer to [Henry2019]_ for a complete
1447    description of this algorithm.  To briefly explain here, the first step is to realize
1448    that we can perfectly encode a single float32 number as the sum of three bfloat16
1449    numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
1450    same number of exponent bits).  This means that the product of two float32 numbers can
1451    be exactly given by the sum of nine products of bfloat16 numbers.  We can then trade
1452    accuracy for speed by dropping some of these products.  The "high" precision algorithm
1453    specifically keeps only the three most significant products, which conveniently excludes
1454    all of the products involving the last 8 mantissa bits of either input.  This means that
1455    we can represent our inputs as the sum of two bfloat16 numbers rather than three.
1456    Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than
1457    float32 ones, it's faster to do three multiplications and two additions with bfloat16
1458    precision than it is to do a single multiplication with float32 precision.
1459
1460    .. [Henry2019] http://arxiv.org/abs/1904.06376
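
    A minimal numerical sketch of the three-way bfloat16 split described above
    (illustrative only; the actual matmul kernels do not decompose values
    elementwise like this)::

        >>> x = torch.tensor(0.123456789, dtype=torch.float32)
        >>> hi = x.to(torch.bfloat16).to(torch.float32)               # top ~8 mantissa bits
        >>> mid = (x - hi).to(torch.bfloat16).to(torch.float32)       # next ~8 bits
        >>> lo = (x - hi - mid).to(torch.bfloat16).to(torch.float32)  # remainder
        >>> bool(hi + mid + lo == x)
        True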
1461
1462    .. note::
1463
1464        This does not change the output dtype of float32 matrix multiplications,
1465        it controls how the internal computation of the matrix multiplication is performed.
1466
1467    .. note::
1468
1469        This does not change the precision of convolution operations. Other flags,
1470        like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
1471        operations.
1472
1473    .. note::
1474
1475        This flag currently only affects one native device type: CUDA.
1476        If "high" or "medium" are set then the TensorFloat32 datatype will be used
1477        when computing float32 matrix multiplications, equivalent to setting
1478        `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
1479        is set then the float32 datatype is used for internal computations, equivalent
1480        to setting `torch.backends.cuda.matmul.allow_tf32 = False`.
1481
1482    Args:
1483        precision (str): can be set to "highest" (default), "high", or "medium" (see above).
1484
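    Example::

        >>> # Illustrative: opt in to TF32 / bfloat16-based algorithms on supported
        >>> # hardware, then restore the default.
        >>> torch.set_float32_matmul_precision("high")
        >>> torch.get_float32_matmul_precision()
        'high'
        >>> torch.set_float32_matmul_precision("highest")
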
1485    """
1486    _C._set_float32_matmul_precision(precision)
1487
1488
1489def set_warn_always(b: builtins.bool, /) -> None:
1490    r"""When this flag is False (default) then some PyTorch warnings may only
1491    appear once per process. This helps avoid excessive warning information.
1492    Setting it to True causes these warnings to always appear, which may be
1493    helpful when debugging.
1494
1495    Args:
1496        b (:class:`bool`): If True, force warnings to always be emitted.
1497                           If False, restore the default behaviour.
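
    Example::

        >>> # Illustrative: emit every warning each time while debugging, then
        >>> # restore the default one-time behaviour.
        >>> torch.set_warn_always(True)
        >>> torch.is_warn_always_enabled()
        True
        >>> torch.set_warn_always(False)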
1498    """
1499    _C._set_warnAlways(b)
1500
1501
1502def is_warn_always_enabled() -> builtins.bool:
1503    r"""Returns True if the global warn_always flag is turned on. Refer to
1504    :func:`torch.set_warn_always` documentation for more details.
1505    """
1506    return _C._get_warnAlways()
1507
1508
1509################################################################################
1510# Define error checking functions
1511################################################################################
1512
1513# These error checking functions must be kept consistent with their C++
1514# equivalents. Their C++ equivalents are mentioned where applicable.
1515
1516
1517def _check_with(
1518    error_type,
1519    cond: _Union[builtins.bool, SymBool],
1520    message: _Callable[[], str],
1521):  # noqa: F811
1522    if not isinstance(cond, (builtins.bool, SymBool)):
1523        raise TypeError(f"cond must be a bool, but got {type(cond)}")
1524
1525    from torch.fx.experimental.symbolic_shapes import expect_true
1526
1527    if expect_true(cond):
1528        return
1529
1530    # error_type must be a subclass of Exception and not a subclass of Warning
1531    assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)
1532
1533    if message is None:
1534        message_evaluated = (
1535            "Expected cond to be True, but got False. (Could this error "
1536            "message be improved? If so, please report an enhancement request "
1537            "to PyTorch.)"
1538        )
1539
1540    else:
1541        if not callable(message):
1542            raise TypeError("message must be a callable")
1543
1544        message_evaluated = str(message())
1545
1546    raise error_type(message_evaluated)
1547
1548
1549def _check(cond, message=None):  # noqa: F811
1550    r"""Throws error containing an optional message if the specified condition
1551    is False.
1552
1553    Error type: ``RuntimeError``
1554
1555    C++ equivalent: ``TORCH_CHECK``
1556
1557    Args:
1558        cond (:class:`bool`): If False, throw error
1559
1560        message (Callable, optional): Callable that returns either a string or
1561            an object that has a ``__str__()`` method to be used as the error
1562            message. Default: ``None``
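
    Example::

        >>> # Illustrative: the message callable is only evaluated when the check fails.
        >>> torch._check(True)  # passes silently
        >>> torch._check(False, lambda: "expected condition to hold")
        RuntimeError: expected condition to hold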
1563    """
1564    _check_with(RuntimeError, cond, message)
1565
1566
1567def _check_is_size(i, message=None):
1568    """Checks that a given integer is a valid size (i.e., is non-negative).
1569    You should use this over _check(i >= 0) because we can use the semantic
1570    information (that i is a size) to make some further inferences in case
1571    i is an unbacked SymInt.
1572
1573    NB: Do NOT use this in contexts where a -1 size would be valid (indicating
1574    to infer the size from context, or if you should wrap-around or truncate).
1575    Only use this if the only valid value is an honest to goodness size.
1576    """
1577    # This is responsible for the expect_true
1578    _check(i >= 0, message)
1579    from torch.fx.experimental.symbolic_shapes import _advise_is_size
1580
1581    _advise_is_size(i)
1582
1583
1584def _check_index(cond, message=None):  # noqa: F811
1585    r"""Throws error containing an optional message if the specified condition
1586    is False.
1587
1588    Error type: ``IndexError``
1589
1590    C++ equivalent: ``TORCH_CHECK_INDEX``
1591
1592    Args:
1593        cond (:class:`bool`): If False, throw error
1594
1595        message (Callable, optional): Callable that returns either a string or
1596            an object that has a ``__str__()`` method to be used as the error
1597            message. Default: ``None``
1598    """
1599    _check_with(IndexError, cond, message)
1600
1601
1602def _check_value(cond, message=None):  # noqa: F811
1603    r"""Throws error containing an optional message if the specified condition
1604    is False.
1605
1606    Error type: ``ValueError``
1607
1608    C++ equivalent: ``TORCH_CHECK_VALUE``
1609
1610    Args:
1611        cond (:class:`bool`): If False, throw error
1612
1613        message (Callable, optional): Callable that returns either a string or
1614            an object that has a ``__str__()`` method to be used as the error
1615            message. Default: ``None``
1616    """
1617    _check_with(ValueError, cond, message)
1618
1619
1620def _check_type(cond, message=None):  # noqa: F811
1621    r"""Throws error containing an optional message if the specified condition
1622    is False.
1623
1624    Error type: ``TypeError``
1625
1626    C++ equivalent: ``TORCH_CHECK_TYPE``
1627
1628    Args:
1629        cond (:class:`bool`): If False, throw error
1630
1631        message (Callable, optional): Callable that returns either a string or
1632            an object that has a ``__str__()`` method to be used as the error
1633            message. Default: ``None``
1634    """
1635    _check_with(TypeError, cond, message)
1636
1637
1638def _check_not_implemented(cond, message=None):  # noqa: F811
1639    r"""Throws error containing an optional message if the specified condition
1640    is False.
1641
1642    Error type: ``NotImplementedError``
1643
1644    C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
1645
1646    Args:
1647        cond (:class:`bool`): If False, throw error
1648
1649        message (Callable, optional): Callable that returns either a string or
1650            an object that has a ``__str__()`` method to be used as the error
1651            message. Default: ``None``
1652    """
1653    _check_with(NotImplementedError, cond, message)
1654
1655
1656def _check_tensor_all_with(error_type, cond, message=None):  # noqa: F811
1657    if not is_tensor(cond):
1658        raise TypeError(f"cond must be a tensor, but got {type(cond)}")
1659
1660    if not cond.dtype == torch.bool:
1661        raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
1662
1663    _check_with(error_type, cond._is_all_true().item(), message)  # type: ignore[arg-type]
1664
1665
1666# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
1667def _check_tensor_all(cond, message=None):  # noqa: F811
1668    r"""Throws error containing an optional message if the specified condition
1669    is False.
1670
1671    Error type: ``RuntimeError``
1672
1673    C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
1674
1675    Args:
1676        cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
1677            element is ``False``, throw error
1678
1679        message (Callable, optional): Callable that returns either a string or
1680            an object that has a ``__str__()`` method to be used as the error
1681            message. Default: ``None``
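
    Example::

        >>> # Illustrative: every element of the bool tensor must be True.
        >>> torch._check_tensor_all(torch.tensor([True, True]))
        >>> torch._check_tensor_all(torch.tensor([True, False]), lambda: "found a False element")
        RuntimeError: found a False element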
1682    """
1683    _check_tensor_all_with(RuntimeError, cond, message)
1684
1685
1686################################################################################
1687# Define numeric constants
1688################################################################################
1689
1690# For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
1691# NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
1692from math import e, inf, nan, pi
1693
1694
1695newaxis: None = None
1696
1697__all__.extend(["e", "pi", "nan", "inf", "newaxis"])
1698
1699################################################################################
1700# Define Storage and Tensor classes
1701################################################################################
1702
1703from torch._tensor import Tensor  # usort: skip
1704
1705# needs to be after torch.Tensor is defined to avoid circular dependencies
1706from torch import storage as storage  # usort: skip
1707from torch.storage import (
1708    _LegacyStorage,
1709    _StorageBase,
1710    _warn_typed_storage_removal,
1711    TypedStorage,
1712    UntypedStorage,
1713)
1714
1715
1716# NOTE: New <type>Storage classes should never be added. When adding a new
1717# dtype, use torch.storage.TypedStorage directly.
1718class ByteStorage(_LegacyStorage):
1719    @classproperty
1720    def dtype(self):
1721        _warn_typed_storage_removal(stacklevel=3)
1722        return self._dtype
1723
1724    @classproperty
1725    def _dtype(self):
1726        return torch.uint8
1727
1728
1729class DoubleStorage(_LegacyStorage):
1730    @classproperty
1731    def dtype(self):
1732        _warn_typed_storage_removal(stacklevel=3)
1733        return self._dtype
1734
1735    @classproperty
1736    def _dtype(self):
1737        return torch.double
1738
1739
1740class FloatStorage(_LegacyStorage):
1741    @classproperty
1742    def dtype(self):
1743        _warn_typed_storage_removal(stacklevel=3)
1744        return self._dtype
1745
1746    @classproperty
1747    def _dtype(self):
1748        return torch.float
1749
1750
1751class HalfStorage(_LegacyStorage):
1752    @classproperty
1753    def dtype(self):
1754        _warn_typed_storage_removal(stacklevel=3)
1755        return self._dtype
1756
1757    @classproperty
1758    def _dtype(self):
1759        return torch.half
1760
1761
1762class LongStorage(_LegacyStorage):
1763    @classproperty
1764    def dtype(self):
1765        _warn_typed_storage_removal(stacklevel=3)
1766        return self._dtype
1767
1768    @classproperty
1769    def _dtype(self):
1770        return torch.long
1771
1772
1773class IntStorage(_LegacyStorage):
1774    @classproperty
1775    def dtype(self):
1776        _warn_typed_storage_removal(stacklevel=3)
1777        return self._dtype
1778
1779    @classproperty
1780    def _dtype(self):
1781        return torch.int
1782
1783
1784class ShortStorage(_LegacyStorage):
1785    @classproperty
1786    def dtype(self):
1787        _warn_typed_storage_removal(stacklevel=3)
1788        return self._dtype
1789
1790    @classproperty
1791    def _dtype(self):
1792        return torch.short
1793
1794
1795class CharStorage(_LegacyStorage):
1796    @classproperty
1797    def dtype(self):
1798        _warn_typed_storage_removal(stacklevel=3)
1799        return self._dtype
1800
1801    @classproperty
1802    def _dtype(self):
1803        return torch.int8
1804
1805
1806class BoolStorage(_LegacyStorage):
1807    @classproperty
1808    def dtype(self):
1809        _warn_typed_storage_removal(stacklevel=3)
1810        return self._dtype
1811
1812    @classproperty
1813    def _dtype(self):
1814        return torch.bool
1815
1816
1817class BFloat16Storage(_LegacyStorage):
1818    @classproperty
1819    def dtype(self):
1820        _warn_typed_storage_removal(stacklevel=3)
1821        return self._dtype
1822
1823    @classproperty
1824    def _dtype(self):
1825        return torch.bfloat16
1826
1827
1828class ComplexDoubleStorage(_LegacyStorage):
1829    @classproperty
1830    def dtype(self):
1831        _warn_typed_storage_removal(stacklevel=3)
1832        return self._dtype
1833
1834    @classproperty
1835    def _dtype(self):
1836        return torch.cdouble
1837
1838
1839class ComplexFloatStorage(_LegacyStorage):
1840    @classproperty
1841    def dtype(self):
1842        _warn_typed_storage_removal(stacklevel=3)
1843        return self._dtype
1844
1845    @classproperty
1846    def _dtype(self):
1847        return torch.cfloat
1848
1849
1850class QUInt8Storage(_LegacyStorage):
1851    @classproperty
1852    def dtype(self):
1853        _warn_typed_storage_removal(stacklevel=3)
1854        return self._dtype
1855
1856    @classproperty
1857    def _dtype(self):
1858        return torch.quint8
1859
1860
1861class QInt8Storage(_LegacyStorage):
1862    @classproperty
1863    def dtype(self):
1864        _warn_typed_storage_removal(stacklevel=3)
1865        return self._dtype
1866
1867    @classproperty
1868    def _dtype(self):
1869        return torch.qint8
1870
1871
1872class QInt32Storage(_LegacyStorage):
1873    @classproperty
1874    def dtype(self):
1875        _warn_typed_storage_removal(stacklevel=3)
1876        return self._dtype
1877
1878    @classproperty
1879    def _dtype(self):
1880        return torch.qint32
1881
1882
1883class QUInt4x2Storage(_LegacyStorage):
1884    @classproperty
1885    def dtype(self):
1886        _warn_typed_storage_removal(stacklevel=3)
1887        return self._dtype
1888
1889    @classproperty
1890    def _dtype(self):
1891        return torch.quint4x2
1892
1893
1894class QUInt2x4Storage(_LegacyStorage):
1895    @classproperty
1896    def dtype(self):
1897        _warn_typed_storage_removal(stacklevel=3)
1898        return self._dtype
1899
1900    @classproperty
1901    def _dtype(self):
1902        return torch.quint2x4
1903
1904
1905_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
1906    UntypedStorage,
1907    DoubleStorage,
1908    FloatStorage,
1909    LongStorage,
1910    IntStorage,
1911    ShortStorage,
1912    CharStorage,
1913    ByteStorage,
1914    HalfStorage,
1915    BoolStorage,
1916    QUInt8Storage,
1917    QInt8Storage,
1918    QInt32Storage,
1919    BFloat16Storage,
1920    ComplexFloatStorage,
1921    ComplexDoubleStorage,
1922    QUInt4x2Storage,
1923    QUInt2x4Storage,
1924    TypedStorage,
1925}
1926
1927# The _tensor_classes set is initialized by the call to initialize_python_bindings.
1928_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
1929
1930# If you edit these imports, please update torch/__init__.py.in as well
1931from torch import amp as amp, random as random, serialization as serialization
1932from torch._tensor_str import set_printoptions
1933from torch.amp import autocast, GradScaler
1934from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
1935from torch.serialization import load, save
1936
1937
1938################################################################################
1939# Initialize extension
1940################################################################################
1941
1942
1943# Shared memory manager needs to know the exact location of the manager executable
1944def _manager_path():
1945    if _running_with_deploy() or platform.system() == "Windows":
1946        return b""
1947    path = get_file_path("torch", "bin", "torch_shm_manager")
1948    prepare_multiprocessing_environment(get_file_path("torch"))
1949    if not os.path.exists(path):
1950        raise RuntimeError("Unable to find torch_shm_manager at " + path)
1951    return path.encode("utf-8")
1952
1953
1954_C._initExtension(_manager_path())
1955
1956del _manager_path
1957
1958# Appease the type checker: it can't deal with direct setting of globals().
1959# Note that we will see "too many" functions when reexporting this way; there
1960# is not a good way to fix this problem.  Perhaps try to redesign VariableFunctions
1961# so that this import is good enough.
1962if TYPE_CHECKING:
1963    # Some type signatures pulled in from _VariableFunctions here clash with
1964    # signatures already imported. For now these clashes are ignored; see
1965    # PR #43339 for details.
1966    from torch._C._VariableFunctions import *  # type: ignore[assignment, misc] # noqa: F403
1967
1968    # Fixup segment_reduce visibility
1969    _segment_reduce = segment_reduce
1970    del segment_reduce  # noqa: F821
1971
1972# Ops not to be exposed in `torch` namespace,
1973# mostly helper ops.
1974PRIVATE_OPS = ("unique_dim",)
1975
1976__name, __obj = "", None
1977for __name in dir(_C._VariableFunctions):
1978    if __name.startswith("__") or __name in PRIVATE_OPS:
1979        continue
1980    __obj = getattr(_C._VariableFunctions, __name)
1981    __obj.__module__ = __name__  # "torch"
1982    # Hide some APIs that should not be public
1983    if __name == "segment_reduce":
1984        # TODO: Once the undocumented FC window is passed, remove the line below
1985        globals()[__name] = __obj
1986        __name = "_" + __name
1987    globals()[__name] = __obj
1988    if not __name.startswith("_"):
1989        __all__.append(__name)
1990
1991del __name, __obj
1992
1993################################################################################
1994# Add torch.dtype instances to the public API
1995################################################################################
1996
1997import torch
1998
1999
2000__all__.extend(
2001    name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
2002)
2003
2004################################################################################
2005# Import TorchDynamo's lazy APIs to avoid circular dependencies
2006################################################################################
2007
2008# needs to be before from torch.functional import * to avoid circular dependencies
2009from torch._compile import _disable_dynamo  # usort: skip
2010
2011################################################################################
2012# Import interface functions defined in Python
2013################################################################################
2014
2015# needs to be after the above ATen bindings so we can overwrite from Python side
2016from torch import _VF as _VF, functional as functional  # usort: skip
2017from torch.functional import *  # usort: skip # noqa: F403
2018
2019################################################################################
2020# Remove unnecessary members
2021################################################################################
2022
2023del _StorageBase
2024del _LegacyStorage
2025
2026################################################################################
2027# Define _assert
2028################################################################################
2029
2030
2031# needs to be before the submodule imports to avoid circular dependencies
2032def _assert(condition, message):
2033    r"""A wrapper around Python's assert which is symbolically traceable."""
2034    if type(condition) is not torch.Tensor and overrides.has_torch_function(
2035        (condition,)
2036    ):
2037        return overrides.handle_torch_function(
2038            _assert, (condition,), condition, message
2039        )
2040    assert condition, message
2041
2042
2043################################################################################
2044# Import most common subpackages
2045################################################################################
2046
2047# Use the redundant form so that type checkers know that these are a part of
2048# the public API. The "regular" import lines are there solely for the runtime
2049# side effect of adding to the imported module's members for other users.
2050
2051# needs to be before import torch.nn as nn to avoid circular dependencies
2052from torch.autograd import (  # usort: skip
2053    enable_grad as enable_grad,
2054    inference_mode as inference_mode,
2055    no_grad as no_grad,
2056    set_grad_enabled as set_grad_enabled,
2057)
2058
2059from torch import (
2060    __config__ as __config__,
2061    __future__ as __future__,
2062    _awaits as _awaits,
2063    autograd as autograd,
2064    backends as backends,
2065    cpu as cpu,
2066    cuda as cuda,
2067    distributed as distributed,
2068    distributions as distributions,
2069    fft as fft,
2070    futures as futures,
2071    hub as hub,
2072    jit as jit,
2073    linalg as linalg,
2074    mps as mps,
2075    mtia as mtia,
2076    multiprocessing as multiprocessing,
2077    nested as nested,
2078    nn as nn,
2079    optim as optim,
2080    overrides as overrides,
2081    profiler as profiler,
2082    sparse as sparse,
2083    special as special,
2084    testing as testing,
2085    types as types,
2086    utils as utils,
2087    xpu as xpu,
2088)
2089from torch.signal import windows as windows
2090
2091
2092# Quantized, sparse, AO, etc. should be last to get imported, as nothing
2093# is expected to depend on them.
2094from torch import ao as ao  # usort: skip
2095
2096# nn.quant* depends on ao -- so should be after those.
2097import torch.nn.intrinsic
2098import torch.nn.qat
2099import torch.nn.quantizable
2100import torch.nn.quantized
2101
2102
2103_C._init_names(list(_storage_classes))
2104
2105# attach docstrings to torch and tensor functions
2106from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs
2107
2108
2109del _torch_docs, _tensor_docs, _storage_docs, _size_docs
2110
2111
2112def compiled_with_cxx11_abi() -> builtins.bool:
2113    r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
2114    return _C._GLIBCXX_USE_CXX11_ABI
2115
2116
2117from torch import _library as _library, _ops as _ops
2118
2119
2120# Import the ops and classes "namespace"
2121from torch._ops import ops as ops  # usort: skip
2122from torch._classes import classes as classes  # usort: skip
2123
2124sys.modules.setdefault(f"{__name__}.ops", ops)
2125sys.modules.setdefault(f"{__name__}.classes", classes)
2126
2127# quantization depends on torch.fx and torch.ops
2128# Import quantization
2129from torch import quantization as quantization  # usort: skip
2130
2131# Import the quasi random sampler
2132from torch import quasirandom as quasirandom  # usort: skip
2133
2134# If you are seeing this, it means that this call site was not checked to see whether
2135# the memory format could be preserved, and it was switched to the old default
2136# behaviour of contiguous
2137legacy_contiguous_format = contiguous_format  # defined by _C._initExtension()
2138
2139# Register fork handler to initialize OpenMP in child processes (see gh-28389)
2140from torch.multiprocessing._atfork import register_after_fork
2141
2142
2143register_after_fork(torch.get_num_threads)
2144del register_after_fork
2145
2146# Import tools that require fully imported torch (for applying
2147# torch.jit.script as a decorator, for instance):
2148from torch._lobpcg import lobpcg as lobpcg
2149
2150
2151# These were previously defined in native_functions.yaml and appeared on the
2152# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
2153# class usage. We add these lines here to preserve backward compatibility.
2154quantized_lstm = ops.aten.quantized_lstm
2155quantized_gru = ops.aten.quantized_gru
2156
2157# Import experimental masked operations support. See
2158# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
2159# information.
2160from torch import masked as masked
2161
2162# Import removed ops with error message about removal
2163from torch._linalg_utils import (  # type: ignore[misc]
2164    _symeig as symeig,
2165    eig,
2166    lstsq,
2167    matrix_rank,
2168    solve,
2169)
2170from torch.utils.dlpack import from_dlpack, to_dlpack
2171
2172
2173class _TorchCompileInductorWrapper:
2174    compiler_name = "inductor"
2175
2176    def __init__(self, mode, options, dynamic):
2177        self.config: _Dict[str, _Any] = {}
2178        self.dynamic = dynamic
2179        self.apply_mode(mode)
2180        self.apply_options(options)
2181
2182        if self.config.get("triton.cudagraphs", False):
2183            os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
2184            # FIXME: CUDA Graph does not work well with CUPTI teardown.
2185            #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
2186            #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
2187            # Workaround: turn off CUPTI teardown when using CUDA Graphs.
2188            os.environ["TEARDOWN_CUPTI"] = "0"
2189
2190    def __eq__(self, other):
2191        return (
2192            isinstance(other, _TorchCompileInductorWrapper)
2193            and self.config == other.config
2194            and self.dynamic == other.dynamic
2195        )
2196
2197    def apply_mode(self, mode: _Optional[str]):
2198        if mode is None or mode == "default":
2199            pass
2200        elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
2201            from torch._inductor import list_mode_options
2202
2203            self.apply_options(list_mode_options(mode, self.dynamic))
2204        else:
2205            raise RuntimeError(
2206                f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
2207            )
2208
2209    def apply_options(self, options: _Optional[_Dict[str, _Any]]):
2210        if not options:
2211            return
2212
2213        from torch._inductor import config
2214
2215        current_config: _Dict[str, _Any] = config.shallow_copy_dict()
2216
2217        for key, val in options.items():
2218            attr_name = key.replace("-", "_")
2219            if attr_name not in current_config:
2220                raise RuntimeError(
2221                    f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
2222                )
2223            if type(val) is not type(current_config[attr_name]):
2224                val_type_str = type(val).__name__
2225                expected_type_str = type(current_config[attr_name]).__name__
2226                raise RuntimeError(
2227                    f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
2228                )
2229            self.config[attr_name] = val
2230
2231    def __call__(self, model_, inputs_):
2232        from torch._inductor.compile_fx import compile_fx
2233
2234        return compile_fx(model_, inputs_, config_patches=self.config)
2235
2236    def get_compiler_config(self):
2237        from torch._inductor.compile_fx import get_patched_config_dict
2238
2239        return get_patched_config_dict(config_patches=self.config)
2240
2241    def reset(self):
2242        from torch._inductor import config
2243
2244        if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
2245            if self.config.get("triton.cudagraphs", True):
2246                from torch._inductor.cudagraph_trees import reset_cudagraph_trees
2247
2248                reset_cudagraph_trees()
2249
2250
2251class _TorchCompileWrapper:
2252    def __init__(self, backend, mode, options, dynamic):
2253        from torch._dynamo.backends.registry import lookup_backend
2254
2255        if isinstance(backend, str):
2256            self.compiler_name = backend
2257        elif hasattr(backend, "__name__"):
2258            self.compiler_name = backend.__name__
2259        else:
2260            self.compiler_name = str(backend)
2261        self.dynamic = dynamic
2262        self.compiler_fn = lookup_backend(backend)
2263        self.kwargs = {}
2264        # only pass the args if they are non-empty
2265        if mode and mode != "default":
2266            self.kwargs["mode"] = mode
2267        if options:
2268            self.kwargs["options"] = options
2269
2270    def __eq__(self, other):
2271        return (
2272            isinstance(other, _TorchCompileWrapper)
2273            and self.compiler_fn == other.compiler_fn
2274            and self.kwargs == other.kwargs
2275            and self.dynamic == other.dynamic
2276        )
2277
2278    def __call__(self, model_, inputs_):
2279        return self.compiler_fn(model_, inputs_, **self.kwargs)
2280
2281    def reset(self):
2282        if hasattr(self.compiler_fn, "reset"):
2283            self.compiler_fn.reset()
2284
2285
2286_InputT = _ParamSpec("_InputT")
2287_RetT = _TypeVar("_RetT")
2288
2289
2290@_overload
2291def compile(
2292    model: _Callable[_InputT, _RetT],
2293    *,
2294    fullgraph: builtins.bool = False,
2295    dynamic: _Optional[builtins.bool] = None,
2296    backend: _Union[str, _Callable] = "inductor",
2297    mode: _Union[str, None] = None,
2298    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
2299    disable: builtins.bool = False,
2300) -> _Callable[_InputT, _RetT]: ...
2301
2302
2303@_overload
2304def compile(
2305    model: None = None,
2306    *,
2307    fullgraph: builtins.bool = False,
2308    dynamic: _Optional[builtins.bool] = None,
2309    backend: _Union[str, _Callable] = "inductor",
2310    mode: _Union[str, None] = None,
2311    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
2312    disable: builtins.bool = False,
2313) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
2314
2315
2316def compile(
2317    model: _Optional[_Callable] = None,
2318    *,
2319    fullgraph: builtins.bool = False,
2320    dynamic: _Optional[builtins.bool] = None,
2321    backend: _Union[str, _Callable] = "inductor",
2322    mode: _Union[str, None] = None,
2323    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
2324    disable: builtins.bool = False,
2325) -> _Union[
2326    _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
2327    _Callable[_InputT, _RetT],
2328]:
2329    """
2330    Optimizes given model/function using TorchDynamo and specified backend.
2331    If you are compiling a :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
2332    to compile the module in-place without changing its structure.
2333
2334    Concretely, for every frame executed within the compiled region, we will attempt
2335    to compile it and cache the compiled result on the code object for future
2336    use.  A single frame may be compiled multiple times if previous compiled
2337    results are not applicable for subsequent calls (this is called a "guard
2338    failure"); you can use TORCH_LOGS=guards to debug these situations.
2339    Multiple compiled results can be associated with a frame up to
2340    ``torch._dynamo.config.cache_size_limit``, which defaults to 8, at which
2341    point we will fall back to eager.  Note that compile caches are per
2342    *code object*, not frame; if you dynamically create multiple copies of a
2343    function, they will all share the same code cache.
2344
2345    Args:
2346       model (Callable): Module/function to optimize
2347       fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions
2348        in the function that it will optimize. If True, then we require that the entire function be
2349        capturable into a single graph. If this is not possible (that is, if there are graph breaks),
2350        then this will raise an error.
2351       dynamic (bool or None): Use dynamic shape tracing.  When this is True, we will up-front attempt
2352        to generate a kernel that is as dynamic as possible to avoid recompilations when
2353        sizes change.  This may not always work as some operations/optimizations will
2354        force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
2355        When this is False, we will NEVER generate dynamic kernels, we will always specialize.
2356        By default (None), we automatically detect if dynamism has occurred and compile a more
2357        dynamic kernel upon recompile.
2358       backend (str or Callable): backend to be used
2359
2360        - "inductor" is the default backend, which is a good balance between performance and overhead
2361
2362        - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
2363
2364        - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`
2365
2366        - To register an out-of-tree custom backend:
2367          https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
2368       mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
2369
2370        - "default" is the default mode, which is a good balance between performance and overhead
2371
2372        - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
2373          useful for small batches.  Reduction of overhead can come at the cost of more memory
2374          usage, as we will cache the workspace memory required for the invocation so that we
2375          do not have to reallocate it on subsequent runs.  Reduction of overhead is not guaranteed
2376          to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
2377          to work; today, we only reduce overhead for CUDA-only graphs which do not mutate inputs.
2378          There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints
2379
2380        - "max-autotune" is a mode that leverages Triton or template based matrix multiplications
2381          on supported devices and Triton based convolutions on GPU.
2382          It enables CUDA graphs by default on GPU.
2383
2384        - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
2385
2386        - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`
2387
2388       options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are
2389
2390        - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set
2391
2392        - `max_autotune` which will profile to pick the best matmul configuration
2393
2394        - `fallback_random` which is useful when debugging accuracy issues
2395
2396        - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores
2397
2398        - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs
2399
2400        - `trace.enabled` which is the most useful debugging flag to turn on
2401
2402        - `trace.graph_diagram` which will show you a picture of your graph after fusion
2403
2404        - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`
2405       disable (bool): Turn torch.compile() into a no-op for testing
2406
2407    Example::
2408
2409        @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
2410        def foo(x):
2411            return torch.sin(x) + torch.cos(x)
2412
2413    """
2414    _C._log_api_usage_once("torch.compile")
2415    if sys.version_info >= (3, 13):
2416        raise RuntimeError("Dynamo is not supported on Python 3.13+")
2417
2418    # Decorator mode
2419    if model is None:
2420
2421        def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
2422            if model is None:
2423                raise RuntimeError("Model can't be None")
2424            return compile(
2425                model,
2426                fullgraph=fullgraph,
2427                dynamic=dynamic,
2428                backend=backend,
2429                mode=mode,
2430                options=options,
2431                disable=disable,
2432            )
2433
2434        return fn
2435
2436    if mode is not None and options is not None:
2437        raise RuntimeError(
2438            "Either mode or options can be specified, but both can't be specified at the same time."
2439        )
2440    if mode is None and options is None:
2441        mode = "default"
2442    if backend == "inductor":
2443        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
2444    else:
2445        backend = _TorchCompileWrapper(backend, mode, options, dynamic)
2446
2447    return torch._dynamo.optimize(
2448        backend=backend,
2449        nopython=fullgraph,
2450        dynamic=dynamic,
2451        disable=disable,
2452    )(model)  # type: ignore[return-value]
2453
2454
2455def _register_device_module(device_type, module):
2456    r"""Register an external runtime module of the specific :attr:`device_type`
2457    supported by torch.
2458
2459    After the :attr:`module` is registered correctly, the user can refer to
2460    the external runtime module as part of torch via the attribute torch.xxx.
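
    A hypothetical sketch (``my_runtime_module`` is a placeholder name, not a real
    module shipped with torch)::

        torch._register_device_module("privateuseone", my_runtime_module)
        # afterwards the module is reachable as ``torch.privateuseone``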
2461    """
2462    # Make sure the device_type represents a supported device type for torch.
2463    device_type = torch.device(device_type).type
2464    m = sys.modules[__name__]
2465    if hasattr(m, device_type):
2466        raise RuntimeError(
2467            f"The runtime module of '{device_type}' has already "
2468            f"been registered with '{getattr(m, device_type)}'"
2469        )
2470    setattr(m, device_type, module)
2471    torch_module_name = ".".join([__name__, device_type])
2472    sys.modules[torch_module_name] = module
2473
2474
2475from torch import (
2476    export as export,
2477    func as func,
2478    library as library,
2479    return_types as return_types,
2480)
2481from torch._higher_order_ops import cond as cond, while_loop as while_loop
2482from torch.func import vmap as vmap
2483
2484
2485if not TYPE_CHECKING:
2486    from torch import _meta_registrations
2487
2488# Enable CUDA Sanitizer
2489if "TORCH_CUDA_SANITIZER" in os.environ:
2490    import torch.cuda._sanitizer as csan
2491
2492    csan.enable_cuda_sanitizer()
2493
2494# Populate magic methods on SymInt and SymFloat
2495import torch.fx.experimental.sym_node
2496
2497
2498# Register MPS specific decomps
2499torch.backends.mps._init()
2500
2501if not _running_with_deploy():
2502    from torch import compiler as compiler
2503
2504    class _TritonLibrary:
2505        lib = torch.library.Library("triton", "DEF")
2506        ops_table: _Dict[_Tuple[str, str], _Callable] = {}
2507
2508        @classmethod
2509        def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
2510            if (op_key, dispatch_key) not in cls.ops_table:
2511                cls.lib.define(full_schema)
2512                cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
2513                cls.ops_table[(op_key, dispatch_key)] = op_impl
2514
2515            return cls.ops_table[(op_key, dispatch_key)]
2516
2517
2518# Deprecated attributes
2519_deprecated_attrs = {
2520    "has_mps": torch.backends.mps.is_built,
2521    "has_cuda": torch.backends.cuda.is_built,
2522    "has_cudnn": torch.backends.cudnn.is_available,
2523    "has_mkldnn": torch.backends.mkldnn.is_available,
2524}
2525
2526if TYPE_CHECKING:
2527    # Import the following modules during type checking to enable code intelligence features,
2528    # such as auto-completion in tools like pylance, even when these modules are not explicitly
2529    # imported in user code.
2530    from torch import (
2531        _dynamo as _dynamo,
2532        _inductor as _inductor,
2533        _subclasses as _subclasses,
2534        onnx as onnx,
2535    )
2536
2537else:
2538    _lazy_modules = {
2539        "_dynamo",
2540        "_inductor",
2541        "_export",
2542        # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
2543        "onnx",
2544    }
2545
2546    def __getattr__(name):
2547        # Deprecated attrs
2548        replacement = _deprecated_attrs.get(name)
2549        if replacement is not None:
2550            import warnings
2551
2552            warnings.warn(
2553                f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
2554                stacklevel=2,
2555            )
2556            return replacement()
2557
2558        # Lazy modules
2559        if name in _lazy_modules:
2560            return importlib.import_module(f".{name}", __name__)
2561
2562        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
2563
2564
2565def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
2566    """
2567    Returns the module associated with a given device (e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
2568    If no device is given, return the module for the current accelerator or CPU if none is present.
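
    Example::

        >>> # Illustrative: the CPU device always resolves to the ``torch.cpu`` module.
        >>> torch.get_device_module("cpu") is torch.cpu
        True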
2569    """
2570    if isinstance(device, torch.device):
2571        device_module_name = device.type
2572    elif isinstance(device, str):
2573        device_module_name = torch.device(device).type
2574    elif device is None:
2575        # Using the default accelerator type. If no accelerator is available, it automatically returns the CPU device.
2576        device_module_name = torch._C._get_accelerator().type
2577    else:
2578        raise RuntimeError(
2579            f"Invalid value of device '{device}', expect torch.device, str, or None"
2580        )
2581    device_module = getattr(torch, device_module_name, None)
2582    if device_module is None:
2583        raise RuntimeError(
2584            f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
2585        )
2586    return device_module
2587
2588
2589def _constrain_as_size(
2590    symbol,
2591    min: _Optional[builtins.int] = None,
2592    max: _Optional[builtins.int] = None,
2593):
2594    """
2595    This indicates that a given int is size-like, and can be used in any context where a size is expected.
2596    You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist(),
2597    which then need to be used as sizes in tensor constructors. Providing these assertions to PyTorch can
2598    help resolve GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
2599
2600    This function has unusual semantics: in some circumstances in framework
2601    code, we will treat this int as >= 2 (when we do a size-oblivious guard).
2602    This makes it easier to use the unbacked int in size contexts,
2603    as we will often attempt to guard on a size being zero/one
2604    (e.g., when computing the contiguity of a tensor, or testing if
2605    broadcasting can occur), which will not work on unbacked SymInts.
2606    However, if we conservatively assume that the size is not zero/one, we will
2607    end up with a graph that will still work even if the size is zero/one.
2608
2609    For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
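
    An illustrative sketch (``lengths`` is a hypothetical 1-D integer tensor whose
    maximum is read out and reused as a size inside a traced/exported region)::

        n = lengths.max().item()
        torch._constrain_as_size(n, min=0)
        out = torch.zeros(n)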
2611    """
2612    torch.sym_constrain_range_for_size(symbol, min=min, max=max)
2613
2614
2615from torch import _logging
2616
2617
2618_logging._init_logs()
2619
2620
2621def _import_device_backends():
2622    """
2623    Leverage the Python plugin mechanism to load out-of-tree device extensions.
2624    See this RFC: https://github.com/pytorch/pytorch/issues/122468
2625    """
2626    from importlib.metadata import entry_points
2627
2628    group_name = "torch.backends"
2629    if sys.version_info < (3, 10):
2630        backend_extensions = entry_points().get(group_name, ())
2631    else:
2632        backend_extensions = entry_points(group=group_name)
2633
2634    for backend_extension in backend_extensions:
2635        try:
2636            # Load the extension
2637            entrypoint = backend_extension.load()
2638            # Call the entrypoint
2639            entrypoint()
2640        except Exception as err:
2641            raise RuntimeError(
2642                f"Failed to load the backend extension: {backend_extension.name}. "
2643                f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
2644            ) from err
2645
2646
2647def _is_device_backend_autoload_enabled() -> builtins.bool:
2648    """
2649    Whether autoloading out-of-tree device extensions is enabled.
2650    The switch depends on the value of the environment variable
2651    `TORCH_DEVICE_BACKEND_AUTOLOAD`.
2652
2653    Returns:
2654        bool: Whether to enable autoloading the extensions. Enabled by default.
2655
2656    Examples:
2657        >>> torch._is_device_backend_autoload_enabled()
2658        True
2659    """
2660    # enabled by default
2661    return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"
2662
2663
2664if _is_device_backend_autoload_enabled():
2665    _import_device_backends()
2666